{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BojqmYOsPk0A"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FtMmJ-pvPfNl"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IyATsAVlPz1W"
      },
      "source": [
        "# Finetuning Gemma Using LitGPT\n",
        "\n",
        "[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open-source language models from Google. Built from the same research and technology used to create the Gemini models, Gemma models are text-to-text, decoder-only large language models (LLMs), available in English, with open weights, pre-trained variants, and instruction-tuned variants.\n",
        "Gemma models are well-suited for various text-generation tasks, including question-answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your cloud infrastructure, democratizing access to state-of-the-art AI models and helping foster innovation for everyone.\n",
        "\n",
        "[LitGPT](https://github.com/Lightning-AI/litgpt) is a framework for working with Large Language models (LLMs). It goes beyond just running LLMs. LitGPT provides a toolkit for the entire LLM lifecycle, including pre-training new models, fine-tuning existing ones for specific tasks, evaluating their performance, and deploying them for real-world use.\n",
        "\n",
        "This notebook guides you through fine-tuning, prompting, and deploying Gemma2 using LitGPT on Google Colab. You'll also upload your fine-tuned model to the Hugging Face Hub.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_LitGPT.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nFgaq_--Qg-O"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run and fine-tune the Gemma model. In this case, you can use a T4 GPU with High RAM:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**. Toggle the High RAM option on.\n",
        "\n",
        "### Setup Hugging Face\n",
        "\n",
        "**Before you dive into the tutorial, let's get you set up with Hugging face:**\n",
        "\n",
        "#### Hugging Face setup\n",
        "\n",
        "1. **Hugging Face Account:**  If you don't already have one, you can create a free Hugging Face account by clicking [here](https://huggingface.co/join).\n",
        "\n",
        "2. **Hugging Face Token:**  Generate a Hugging Face access (with `write` permission) token by clicking [here](https://huggingface.co/settings/tokens). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bLEUJYZ8QmGz"
      },
      "source": [
        "### Configure your HF token\n",
        "\n",
        "Add your Hugging Face token to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create a new secret with the name `HF_TOKEN`.\n",
        "3. Copy/paste your HF token key into the Value input box of `HF_TOKEN`.\n",
        "4. Toggle the button on the left to allow notebook access to the secret."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "hK-qUiQGQbe5"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"HF_TOKEN\"] = userdata.get(\"HF_TOKEN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5DrB-trDQsgE"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "First, you must install the python package for LitGPT."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "LLjxxhk2Qrf_"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting litgpt==0.5.3 (from litgpt[all]==0.5.3)\n",
            "  Downloading litgpt-0.5.3-py3-none-any.whl.metadata (41 kB)\n",
            "\u001b[?25l     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/41.9 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.9/41.9 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torch<=2.4.1,>=2.2.0 (from litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)\n",
            "Requirement already satisfied: numpy<2.0 in /usr/local/lib/python3.10/dist-packages (from litgpt==0.5.3->litgpt[all]==0.5.3) (1.26.4)\n",
            "Collecting lightning==2.4.0 (from litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)\n",
            "Collecting jsonargparse<=4.32.1,>=4.30.1 (from jsonargparse[signatures]<=4.32.1,>=4.30.1->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading jsonargparse-4.32.1-py3-none-any.whl.metadata (12 kB)\n",
            "Requirement already satisfied: huggingface-hub>=0.23.5 in /usr/local/lib/python3.10/dist-packages (from litgpt==0.5.3->litgpt[all]==0.5.3) (0.26.2)\n",
            "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from litgpt==0.5.3->litgpt[all]==0.5.3) (0.4.5)\n",
            "Requirement already satisfied: tokenizers>=0.15.2 in /usr/local/lib/python3.10/dist-packages (from litgpt==0.5.3->litgpt[all]==0.5.3) (0.20.3)\n",
            "Requirement already satisfied: tqdm>=4.66.0 in /usr/local/lib/python3.10/dist-packages (from litgpt==0.5.3->litgpt[all]==0.5.3) (4.66.6)\n",
            "Collecting bitsandbytes==0.42.0 (from litgpt[all]==0.5.3)\n",
            "  Downloading bitsandbytes-0.42.0-py3-none-any.whl.metadata (9.9 kB)\n",
            "Requirement already satisfied: sentencepiece>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (0.2.0)\n",
            "Requirement already satisfied: requests>=2.31.0 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (2.32.3)\n",
            "Collecting litdata==0.2.17 (from litgpt[all]==0.5.3)\n",
            "  Downloading litdata-0.2.17-py3-none-any.whl.metadata (31 kB)\n",
            "Collecting litserve>=0.1.5 (from litgpt[all]==0.5.3)\n",
            "  Downloading litserve-0.2.5-py3-none-any.whl.metadata (16 kB)\n",
            "Collecting zstandard>=0.22.0 (from litgpt[all]==0.5.3)\n",
            "  Downloading zstandard-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)\n",
            "Requirement already satisfied: pandas>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (2.2.2)\n",
            "Requirement already satisfied: pyarrow>=15.0.2 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (17.0.0)\n",
            "Requirement already satisfied: tensorboard>=2.14.0 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (2.17.1)\n",
            "Collecting torchmetrics>=1.3.1 (from litgpt[all]==0.5.3)\n",
            "  Downloading torchmetrics-1.6.0-py3-none-any.whl.metadata (20 kB)\n",
            "Collecting datasets>=2.18.0 (from litgpt[all]==0.5.3)\n",
            "  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n",
            "Requirement already satisfied: transformers>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from litgpt[all]==0.5.3) (4.46.2)\n",
            "Collecting lm-eval>=0.4.2 (from litgpt[all]==0.5.3)\n",
            "  Downloading lm_eval-0.4.5-py3-none-any.whl.metadata (44 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.0/44.0 kB\u001b[0m \u001b[31m3.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting uvloop>=0.2.0 (from litgpt[all]==0.5.3)\n",
            "  Downloading uvloop-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from bitsandbytes==0.42.0->litgpt[all]==0.5.3) (1.13.1)\n",
            "Requirement already satisfied: PyYAML<8.0,>=5.4 in /usr/local/lib/python3.10/dist-packages (from lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3) (6.0.2)\n",
            "Requirement already satisfied: fsspec<2026.0,>=2022.5.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<2026.0,>=2022.5.0->lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3) (2024.10.0)\n",
            "Collecting lightning-utilities<2.0,>=0.10.0 (from lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadata (5.2 kB)\n",
            "Requirement already satisfied: packaging<25.0,>=20.0 in /usr/local/lib/python3.10/dist-packages (from lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3) (24.2)\n",
            "Requirement already satisfied: typing-extensions<6.0,>=4.4.0 in /usr/local/lib/python3.10/dist-packages (from lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3) (4.12.2)\n",
            "Collecting pytorch-lightning (from lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from litdata==0.2.17->litgpt[all]==0.5.3) (3.16.1)\n",
            "Collecting boto3 (from litdata==0.2.17->litgpt[all]==0.5.3)\n",
            "  Downloading boto3-1.35.72-py3-none-any.whl.metadata (6.7 kB)\n",
            "Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.18.0->litgpt[all]==0.5.3)\n",
            "  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
            "Collecting xxhash (from datasets>=2.18.0->litgpt[all]==0.5.3)\n",
            "  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n",
            "Collecting multiprocess<0.70.17 (from datasets>=2.18.0->litgpt[all]==0.5.3)\n",
            "  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n",
            "Collecting fsspec<2026.0,>=2022.5.0 (from fsspec[http]<2026.0,>=2022.5.0->lightning==2.4.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.18.0->litgpt[all]==0.5.3) (3.11.2)\n",
            "Collecting hf-transfer>=0.1.4 (from huggingface-hub[hf_transfer]>=0.21.0; extra == \"all\"->litgpt[all]==0.5.3)\n",
            "  Downloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.7 kB)\n",
            "Requirement already satisfied: docstring-parser>=0.15 in /usr/local/lib/python3.10/dist-packages (from jsonargparse[signatures]<=4.32.1,>=4.30.1->litgpt==0.5.3->litgpt[all]==0.5.3) (0.16)\n",
            "Collecting typeshed-client>=2.1.0 (from jsonargparse[signatures]<=4.32.1,>=4.30.1->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading typeshed_client-2.7.0-py3-none-any.whl.metadata (7.9 kB)\n",
            "Collecting fastapi>=0.100 (from litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading fastapi-0.115.5-py3-none-any.whl.metadata (27 kB)\n",
            "Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from litserve>=0.1.5->litgpt[all]==0.5.3) (0.27.2)\n",
            "Collecting uvicorn>=0.29.0 (from uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading uvicorn-0.32.1-py3-none-any.whl.metadata (6.6 kB)\n",
            "Requirement already satisfied: accelerate>=0.26.0 in /usr/local/lib/python3.10/dist-packages (from lm-eval>=0.4.2->litgpt[all]==0.5.3) (1.1.1)\n",
            "Collecting evaluate (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)\n",
            "Collecting jsonlines (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading jsonlines-4.0.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: numexpr in /usr/local/lib/python3.10/dist-packages (from lm-eval>=0.4.2->litgpt[all]==0.5.3) (2.10.1)\n",
            "Requirement already satisfied: peft>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from lm-eval>=0.4.2->litgpt[all]==0.5.3) (0.13.2)\n",
            "Collecting pybind11>=2.6.2 (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading pybind11-2.13.6-py3-none-any.whl.metadata (9.5 kB)\n",
            "Collecting pytablewriter (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading pytablewriter-1.2.0-py3-none-any.whl.metadata (37 kB)\n",
            "Collecting rouge-score>=0.0.4 (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading rouge_score-0.1.2.tar.gz (17 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting sacrebleu>=1.5.0 (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading sacrebleu-2.4.3-py3-none-any.whl.metadata (51 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.8/51.8 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: scikit-learn>=0.24.1 in /usr/local/lib/python3.10/dist-packages (from lm-eval>=0.4.2->litgpt[all]==0.5.3) (1.5.2)\n",
            "Collecting sqlitedict (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading sqlitedict-2.1.0.tar.gz (21 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting tqdm-multiprocess (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading tqdm_multiprocess-0.0.11-py3-none-any.whl.metadata (5.7 kB)\n",
            "Collecting word2number (from lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading word2number-1.1.zip (9.7 kB)\n",
            "  Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: more-itertools in /usr/local/lib/python3.10/dist-packages (from lm-eval>=0.4.2->litgpt[all]==0.5.3) (10.5.0)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.9.0->litgpt[all]==0.5.3) (2.8.2)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.9.0->litgpt[all]==0.5.3) (2024.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas>=1.9.0->litgpt[all]==0.5.3) (2024.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->litgpt[all]==0.5.3) (3.4.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->litgpt[all]==0.5.3) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->litgpt[all]==0.5.3) (2.2.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->litgpt[all]==0.5.3) (2024.8.30)\n",
            "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (1.4.0)\n",
            "Requirement already satisfied: grpcio>=1.48.2 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (1.68.0)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (3.7)\n",
            "Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (4.25.5)\n",
            "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (75.1.0)\n",
            "Requirement already satisfied: six>1.9 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (1.16.0)\n",
            "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (0.7.2)\n",
            "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard>=2.14.0->litgpt[all]==0.5.3) (3.1.3)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3) (1.13.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3) (3.1.4)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-curand-cu12==10.3.2.106 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-nccl-cu12==2.20.5 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n",
            "Collecting nvidia-nvtx-cu12==12.1.105 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n",
            "Collecting triton==3.0.0 (from torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3)\n",
            "  Downloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.3 kB)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3) (12.6.77)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.38.0->litgpt[all]==0.5.3) (2024.9.11)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.26.0->lm-eval>=0.4.2->litgpt[all]==0.5.3) (5.9.5)\n",
            "Collecting starlette<0.42.0,>=0.40.0 (from fastapi>=0.100->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading starlette-0.41.3-py3-none-any.whl.metadata (6.0 kB)\n",
            "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from fastapi>=0.100->litserve>=0.1.5->litgpt[all]==0.5.3) (2.9.2)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (2.4.3)\n",
            "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (1.3.1)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (24.2.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (1.5.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (6.1.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (0.2.0)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (1.17.2)\n",
            "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.18.0->litgpt[all]==0.5.3) (4.0.3)\n",
            "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge-score>=0.0.4->lm-eval>=0.4.2->litgpt[all]==0.5.3) (3.9.1)\n",
            "Collecting portalocker (from sacrebleu>=1.5.0->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading portalocker-3.0.0-py3-none-any.whl.metadata (8.5 kB)\n",
            "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval>=0.4.2->litgpt[all]==0.5.3) (0.9.0)\n",
            "Collecting colorama (from sacrebleu>=1.5.0->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)\n",
            "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.5.0->lm-eval>=0.4.2->litgpt[all]==0.5.3) (5.3.0)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval>=0.4.2->litgpt[all]==0.5.3) (1.4.2)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.24.1->lm-eval>=0.4.2->litgpt[all]==0.5.3) (3.5.0)\n",
            "Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from typeshed-client>=2.1.0->jsonargparse[signatures]<=4.32.1,>=4.30.1->litgpt==0.5.3->litgpt[all]==0.5.3) (6.4.5)\n",
            "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.29.0->uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3) (8.1.7)\n",
            "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.29.0->uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3) (0.14.0)\n",
            "Collecting httptools>=0.6.3 (from uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading httptools-0.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.6 kB)\n",
            "Collecting python-dotenv>=0.13 (from uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading python_dotenv-1.0.1-py3-none-any.whl.metadata (23 kB)\n",
            "Collecting watchfiles>=0.13 (from uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)\n",
            "Collecting websockets>=10.4 (from uvicorn[standard]>=0.29.0->litserve>=0.1.5->litgpt[all]==0.5.3)\n",
            "  Downloading websockets-14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard>=2.14.0->litgpt[all]==0.5.3) (3.0.2)\n",
            "Collecting botocore<1.36.0,>=1.35.72 (from boto3->litdata==0.2.17->litgpt[all]==0.5.3)\n",
            "  Downloading botocore-1.35.72-py3-none-any.whl.metadata (5.7 kB)\n",
            "Collecting jmespath<2.0.0,>=0.7.1 (from boto3->litdata==0.2.17->litgpt[all]==0.5.3)\n",
            "  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)\n",
            "Collecting s3transfer<0.11.0,>=0.10.0 (from boto3->litdata==0.2.17->litgpt[all]==0.5.3)\n",
            "  Downloading s3transfer-0.10.4-py3-none-any.whl.metadata (1.7 kB)\n",
            "Requirement already satisfied: anyio in /usr/local/lib/python3.10/dist-packages (from httpx->litserve>=0.1.5->litgpt[all]==0.5.3) (3.7.1)\n",
            "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx->litserve>=0.1.5->litgpt[all]==0.5.3) (1.0.7)\n",
            "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->litserve>=0.1.5->litgpt[all]==0.5.3) (1.3.1)\n",
            "Collecting DataProperty<2,>=1.0.1 (from pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading DataProperty-1.0.1-py3-none-any.whl.metadata (11 kB)\n",
            "Collecting mbstrdecoder<2,>=1.0.0 (from pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading mbstrdecoder-1.1.3-py3-none-any.whl.metadata (4.0 kB)\n",
            "Collecting pathvalidate<4,>=2.3.0 (from pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading pathvalidate-3.2.1-py3-none-any.whl.metadata (12 kB)\n",
            "Collecting tabledata<2,>=1.3.1 (from pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading tabledata-1.3.3-py3-none-any.whl.metadata (3.7 kB)\n",
            "Collecting tcolorpy<1,>=0.0.5 (from pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading tcolorpy-0.1.6-py3-none-any.whl.metadata (6.4 kB)\n",
            "Collecting typepy<2,>=1.3.2 (from typepy[datetime]<2,>=1.3.2->pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3)\n",
            "  Downloading typepy-1.3.2-py3-none-any.whl.metadata (9.3 kB)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch<=2.4.1,>=2.2.0->litgpt==0.5.3->litgpt[all]==0.5.3) (1.3.0)\n",
            "Requirement already satisfied: chardet<6,>=3.0.4 in /usr/local/lib/python3.10/dist-packages (from mbstrdecoder<2,>=1.0.0->pytablewriter->lm-eval>=0.4.2->litgpt[all]==0.5.3) (5.2.0)\n",
            "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4->fastapi>=0.100->litserve>=0.1.5->litgpt[all]==0.5.3) (0.7.0)\n",
            "Requirement already satisfied: pydantic-core==2.23.4 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4->fastapi>=0.100->litserve>=0.1.5->litgpt[all]==0.5.3) (2.23.4)\n",
            "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio->httpx->litserve>=0.1.5->litgpt[all]==0.5.3) (1.2.2)\n",
            "Downloading litgpt-0.5.3-py3-none-any.whl (176 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.3/176.3 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading bitsandbytes-0.42.0-py3-none-any.whl (105.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.0/105.0 MB\u001b[0m \u001b[31m22.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading lightning-2.4.0-py3-none-any.whl (810 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m811.0/811.0 kB\u001b[0m \u001b[31m49.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading litdata-0.2.17-py3-none-any.whl (125 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.7/125.7 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading datasets-3.1.0-py3-none-any.whl (480 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m34.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading jsonargparse-4.32.1-py3-none-any.whl (207 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.3/207.3 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading litserve-0.2.5-py3-none-any.whl (48 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m48.2/48.2 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading lm_eval-0.4.5-py3-none-any.whl (2.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m85.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading torch-2.4.1-cp310-cp310-manylinux1_x86_64.whl (797.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m797.1/797.1 MB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m105.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m48.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m18.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m38.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m5.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.2/176.2 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading torchmetrics-1.6.0-py3-none-any.whl (926 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m926.4/926.4 kB\u001b[0m \u001b[31m49.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading uvloop-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m100.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading zstandard-0.23.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m116.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading evaluate-0.4.3-py3-none-any.whl (84 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading fastapi-0.115.5-py3-none-any.whl (94 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m94.9/94.9 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading hf_transfer-0.1.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.6 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m99.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading lightning_utilities-0.11.9-py3-none-any.whl (28 kB)\n",
            "Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading pybind11-2.13.6-py3-none-any.whl (243 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m243.3/243.3 kB\u001b[0m \u001b[31m20.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading sacrebleu-2.4.3-py3-none-any.whl (103 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m104.0/104.0 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading typeshed_client-2.7.0-py3-none-any.whl (624 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m624.4/624.4 kB\u001b[0m \u001b[31m46.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading uvicorn-0.32.1-py3-none-any.whl (63 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.8/63.8 kB\u001b[0m \u001b[31m5.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading boto3-1.35.72-py3-none-any.whl (139 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m139.2/139.2 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading jsonlines-4.0.0-py3-none-any.whl (8.7 kB)\n",
            "Downloading pytablewriter-1.2.0-py3-none-any.whl (111 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m111.1/111.1 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading pytorch_lightning-2.4.0-py3-none-any.whl (815 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m815.2/815.2 kB\u001b[0m \u001b[31m48.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tqdm_multiprocess-0.0.11-py3-none-any.whl (9.8 kB)\n",
            "Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading botocore-1.35.72-py3-none-any.whl (13.1 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.1/13.1 MB\u001b[0m \u001b[31m111.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading DataProperty-1.0.1-py3-none-any.whl (27 kB)\n",
            "Downloading httptools-0.6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (442 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m442.1/442.1 kB\u001b[0m \u001b[31m33.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading jmespath-1.0.1-py3-none-any.whl (20 kB)\n",
            "Downloading mbstrdecoder-1.1.3-py3-none-any.whl (7.8 kB)\n",
            "Downloading pathvalidate-3.2.1-py3-none-any.whl (23 kB)\n",
            "Downloading python_dotenv-1.0.1-py3-none-any.whl (19 kB)\n",
            "Downloading s3transfer-0.10.4-py3-none-any.whl (83 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.2/83.2 kB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading starlette-0.41.3-py3-none-any.whl (73 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m73.2/73.2 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tabledata-1.3.3-py3-none-any.whl (11 kB)\n",
            "Downloading tcolorpy-0.1.6-py3-none-any.whl (8.1 kB)\n",
            "Downloading typepy-1.3.2-py3-none-any.whl (31 kB)\n",
            "Downloading watchfiles-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (442 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m442.6/442.6 kB\u001b[0m \u001b[31m32.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading websockets-14.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.2/168.2 kB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
            "Downloading portalocker-3.0.0-py3-none-any.whl (19 kB)\n",
            "Building wheels for collected packages: rouge-score, sqlitedict, word2number\n",
            "  Building wheel for rouge-score (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for rouge-score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=71b69f6a4a44643c3f5ec76a03ac4bc9f71988909fa09aac8b0d40458d19d271\n",
            "  Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n",
            "  Building wheel for sqlitedict (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sqlitedict: filename=sqlitedict-2.1.0-py3-none-any.whl size=16864 sha256=8e91876754c6bdfac16f5826524eed158afe245a07720f9db280fd5573e6ecfa\n",
            "  Stored in directory: /root/.cache/pip/wheels/79/d6/e7/304e0e6cb2221022c26d8161f7c23cd4f259a9e41e8bbcfabd\n",
            "  Building wheel for word2number (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for word2number: filename=word2number-1.1-py3-none-any.whl size=5568 sha256=5199840c8aa44227ab60f898b4da407553ff13af6998bfa831acbca2052013ca\n",
            "  Stored in directory: /root/.cache/pip/wheels/84/ff/26/d3cfbd971e96c5aa3737ecfced81628830d7359b55fbb8ca3b\n",
            "Successfully built rouge-score sqlitedict word2number\n",
            "Installing collected packages: word2number, sqlitedict, zstandard, xxhash, websockets, uvloop, uvicorn, typeshed-client, triton, tcolorpy, python-dotenv, pybind11, portalocker, pathvalidate, nvidia-nvtx-cu12, nvidia-nccl-cu12, nvidia-cusparse-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, mbstrdecoder, lightning-utilities, jsonlines, jsonargparse, jmespath, httptools, hf-transfer, fsspec, dill, colorama, watchfiles, typepy, tqdm-multiprocess, starlette, sacrebleu, rouge-score, nvidia-cusolver-cu12, nvidia-cudnn-cu12, multiprocess, botocore, bitsandbytes, torch, s3transfer, fastapi, torchmetrics, litserve, DataProperty, boto3, tabledata, pytorch-lightning, litdata, datasets, pytablewriter, lightning, evaluate, lm-eval, litgpt\n",
            "  Attempting uninstall: nvidia-nccl-cu12\n",
            "    Found existing installation: nvidia-nccl-cu12 2.23.4\n",
            "    Uninstalling nvidia-nccl-cu12-2.23.4:\n",
            "      Successfully uninstalled nvidia-nccl-cu12-2.23.4\n",
            "  Attempting uninstall: nvidia-cusparse-cu12\n",
            "    Found existing installation: nvidia-cusparse-cu12 12.5.4.2\n",
            "    Uninstalling nvidia-cusparse-cu12-12.5.4.2:\n",
            "      Successfully uninstalled nvidia-cusparse-cu12-12.5.4.2\n",
            "  Attempting uninstall: nvidia-curand-cu12\n",
            "    Found existing installation: nvidia-curand-cu12 10.3.7.77\n",
            "    Uninstalling nvidia-curand-cu12-10.3.7.77:\n",
            "      Successfully uninstalled nvidia-curand-cu12-10.3.7.77\n",
            "  Attempting uninstall: nvidia-cufft-cu12\n",
            "    Found existing installation: nvidia-cufft-cu12 11.3.0.4\n",
            "    Uninstalling nvidia-cufft-cu12-11.3.0.4:\n",
            "      Successfully uninstalled nvidia-cufft-cu12-11.3.0.4\n",
            "  Attempting uninstall: nvidia-cuda-runtime-cu12\n",
            "    Found existing installation: nvidia-cuda-runtime-cu12 12.6.77\n",
            "    Uninstalling nvidia-cuda-runtime-cu12-12.6.77:\n",
            "      Successfully uninstalled nvidia-cuda-runtime-cu12-12.6.77\n",
            "  Attempting uninstall: nvidia-cuda-cupti-cu12\n",
            "    Found existing installation: nvidia-cuda-cupti-cu12 12.6.80\n",
            "    Uninstalling nvidia-cuda-cupti-cu12-12.6.80:\n",
            "      Successfully uninstalled nvidia-cuda-cupti-cu12-12.6.80\n",
            "  Attempting uninstall: nvidia-cublas-cu12\n",
            "    Found existing installation: nvidia-cublas-cu12 12.6.3.3\n",
            "    Uninstalling nvidia-cublas-cu12-12.6.3.3:\n",
            "      Successfully uninstalled nvidia-cublas-cu12-12.6.3.3\n",
            "  Attempting uninstall: fsspec\n",
            "    Found existing installation: fsspec 2024.10.0\n",
            "    Uninstalling fsspec-2024.10.0:\n",
            "      Successfully uninstalled fsspec-2024.10.0\n",
            "  Attempting uninstall: nvidia-cusolver-cu12\n",
            "    Found existing installation: nvidia-cusolver-cu12 11.7.1.2\n",
            "    Uninstalling nvidia-cusolver-cu12-11.7.1.2:\n",
            "      Successfully uninstalled nvidia-cusolver-cu12-11.7.1.2\n",
            "  Attempting uninstall: nvidia-cudnn-cu12\n",
            "    Found existing installation: nvidia-cudnn-cu12 9.5.1.17\n",
            "    Uninstalling nvidia-cudnn-cu12-9.5.1.17:\n",
            "      Successfully uninstalled nvidia-cudnn-cu12-9.5.1.17\n",
            "  Attempting uninstall: torch\n",
            "    Found existing installation: torch 2.5.1+cu121\n",
            "    Uninstalling torch-2.5.1+cu121:\n",
            "      Successfully uninstalled torch-2.5.1+cu121\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\n",
            "torchaudio 2.5.1+cu121 requires torch==2.5.1, but you have torch 2.4.1 which is incompatible.\n",
            "torchvision 0.20.1+cu121 requires torch==2.5.1, but you have torch 2.4.1 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed DataProperty-1.0.1 bitsandbytes-0.42.0 boto3-1.35.72 botocore-1.35.72 colorama-0.4.6 datasets-3.1.0 dill-0.3.8 evaluate-0.4.3 fastapi-0.115.5 fsspec-2024.9.0 hf-transfer-0.1.8 httptools-0.6.4 jmespath-1.0.1 jsonargparse-4.32.1 jsonlines-4.0.0 lightning-2.4.0 lightning-utilities-0.11.9 litdata-0.2.17 litgpt-0.5.3 litserve-0.2.5 lm-eval-0.4.5 mbstrdecoder-1.1.3 multiprocess-0.70.16 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvtx-cu12-12.1.105 pathvalidate-3.2.1 portalocker-3.0.0 pybind11-2.13.6 pytablewriter-1.2.0 python-dotenv-1.0.1 pytorch-lightning-2.4.0 rouge-score-0.1.2 s3transfer-0.10.4 sacrebleu-2.4.3 sqlitedict-2.1.0 starlette-0.41.3 tabledata-1.3.3 tcolorpy-0.1.6 torch-2.4.1 torchmetrics-1.6.0 tqdm-multiprocess-0.0.11 triton-3.0.0 typepy-1.3.2 typeshed-client-2.7.0 uvicorn-0.32.1 uvloop-0.21.0 watchfiles-1.0.0 websockets-14.1 word2number-1.1 xxhash-3.5.0 zstandard-0.23.0\n"
          ]
        }
      ],
      "source": [
        "!pip install \"litgpt[all]==0.5.3\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4l--esz6VnHf"
      },
      "source": [
        "Installing `litgpt` downgrades the pre-installed PyTorch to version 2.4.1, which makes the pre-installed `torchvision` and `torchaudio` packages incompatible (both require PyTorch 2.5.1, as the dependency-conflict message in the output above shows). To avoid errors when pushing the fine-tuned model to the Hugging Face Hub, reinstall `torchvision` and `torchaudio` at versions that match PyTorch 2.4.1.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "6tIrBREKLxlL"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found existing installation: torchvision 0.20.1+cu121\n",
            "Uninstalling torchvision-0.20.1+cu121:\n",
            "  Successfully uninstalled torchvision-0.20.1+cu121\n",
            "Found existing installation: torchaudio 2.5.1+cu121\n",
            "Uninstalling torchaudio-2.5.1+cu121:\n",
            "  Successfully uninstalled torchaudio-2.5.1+cu121\n",
            "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
            "Collecting torchaudio==2.4.1\n",
            "  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m47.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hCollecting torchvision==0.19.1\n",
            "  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl (7.1 MB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.1/7.1 MB\u001b[0m \u001b[31m62.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: torch==2.4.1 in /usr/local/lib/python3.10/dist-packages (from torchaudio==2.4.1) (2.4.1)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision==0.19.1) (1.26.4)\n",
            "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision==0.19.1) (11.0.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (3.16.1)\n",
            "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (4.12.2)\n",
            "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (1.13.1)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (3.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (3.1.4)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (2024.9.0)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.105)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (9.1.0.70)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.3.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (11.0.2.54)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (10.3.2.106)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (11.4.5.107)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.0.106)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (2.20.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (12.1.105)\n",
            "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.10/dist-packages (from torch==2.4.1->torchaudio==2.4.1) (3.0.0)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.4.1->torchaudio==2.4.1) (12.6.77)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch==2.4.1->torchaudio==2.4.1) (3.0.2)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch==2.4.1->torchaudio==2.4.1) (1.3.0)\n",
            "Installing collected packages: torchvision, torchaudio\n",
            "Successfully installed torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121\n"
          ]
        }
      ],
      "source": [
        "!pip uninstall -y torchvision torchaudio\n",
        "!pip install 'torchaudio==2.4.1' 'torchvision==0.19.1' --index-url https://download.pytorch.org/whl/cu121"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_wsT7Eb_1esZ"
      },
      "source": [
        "## Overview\n",
        "\n",
        "LitGPT lets you work with a wide range of LLMs locally. It implements each model from scratch, without extra layers of abstraction, so you keep full control over the code.\n",
        "\n",
        "In this notebook, you'll implement the following workflows on Gemma 2 using LitGPT:\n",
        "\n",
        "1. Fine-tune Gemma 2 on a small subset of the Alpaca dataset.\n",
        "2. Perform inference using the fine-tuned model.\n",
        "3. Deploy the fine-tuned model and send inference requests to the server using Python `requests`.\n",
        "4. Upload the fine-tuned model to a Hugging Face Hub repository.\n",
        "\n",
        "You'll carry out these tasks with LitGPT's command-line interface. LitGPT also has an experimental Python API, which is described in the [LitGPT Python API tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/python-api.md)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tMfQBPRYsknZ"
      },
      "source": [
        "## 1. Fine-tune Gemma 2 using LitGPT\n",
        "\n",
        "In this section, you will fine-tune Gemma 2 on a small subset of the Alpaca dataset using the LitGPT command-line interface.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cJB1fv7fCYbv"
      },
      "source": [
        "### Download the Gemma 2 model\n",
        "\n",
        "LitGPT supports a variety of open-source models, including Gemma. To list the supported models, run the following command:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "X-Qk8xegz3nF"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Please specify --repo_id <repo_id>. Available values:\n",
            "codellama/CodeLlama-13b-hf\n",
            "codellama/CodeLlama-13b-Instruct-hf\n",
            "codellama/CodeLlama-13b-Python-hf\n",
            "codellama/CodeLlama-34b-hf\n",
            "codellama/CodeLlama-34b-Instruct-hf\n",
            "codellama/CodeLlama-34b-Python-hf\n",
            "codellama/CodeLlama-70b-hf\n",
            "codellama/CodeLlama-70b-Instruct-hf\n",
            "codellama/CodeLlama-70b-Python-hf\n",
            "codellama/CodeLlama-7b-hf\n",
            "codellama/CodeLlama-7b-Instruct-hf\n",
            "codellama/CodeLlama-7b-Python-hf\n",
            "databricks/dolly-v2-12b\n",
            "databricks/dolly-v2-3b\n",
            "databricks/dolly-v2-7b\n",
            "EleutherAI/pythia-1.4b\n",
            "EleutherAI/pythia-1.4b-deduped\n",
            "EleutherAI/pythia-12b\n",
            "EleutherAI/pythia-12b-deduped\n",
            "EleutherAI/pythia-14m\n",
            "EleutherAI/pythia-160m\n",
            "EleutherAI/pythia-160m-deduped\n",
            "EleutherAI/pythia-1b\n",
            "EleutherAI/pythia-1b-deduped\n",
            "EleutherAI/pythia-2.8b\n",
            "EleutherAI/pythia-2.8b-deduped\n",
            "EleutherAI/pythia-31m\n",
            "EleutherAI/pythia-410m\n",
            "EleutherAI/pythia-410m-deduped\n",
            "EleutherAI/pythia-6.9b\n",
            "EleutherAI/pythia-6.9b-deduped\n",
            "EleutherAI/pythia-70m\n",
            "EleutherAI/pythia-70m-deduped\n",
            "garage-bAInd/Camel-Platypus2-13B\n",
            "garage-bAInd/Camel-Platypus2-70B\n",
            "garage-bAInd/Platypus-30B\n",
            "garage-bAInd/Platypus2-13B\n",
            "garage-bAInd/Platypus2-70B\n",
            "garage-bAInd/Platypus2-70B-instruct\n",
            "garage-bAInd/Platypus2-7B\n",
            "garage-bAInd/Stable-Platypus2-13B\n",
            "google/codegemma-7b-it\n",
            "google/gemma-2-27b\n",
            "google/gemma-2-27b-it\n",
            "google/gemma-2-2b\n",
            "google/gemma-2-2b-it\n",
            "google/gemma-2-9b\n",
            "google/gemma-2-9b-it\n",
            "google/gemma-2b\n",
            "google/gemma-2b-it\n",
            "google/gemma-7b\n",
            "google/gemma-7b-it\n",
            "h2oai/h2o-danube2-1.8b-chat\n",
            "keeeeenw/MicroLlama\n",
            "lmsys/longchat-13b-16k\n",
            "lmsys/longchat-7b-16k\n",
            "lmsys/vicuna-13b-v1.3\n",
            "lmsys/vicuna-13b-v1.5\n",
            "lmsys/vicuna-13b-v1.5-16k\n",
            "lmsys/vicuna-33b-v1.3\n",
            "lmsys/vicuna-7b-v1.3\n",
            "lmsys/vicuna-7b-v1.5\n",
            "lmsys/vicuna-7b-v1.5-16k\n",
            "meta-llama/Llama-2-13b-chat-hf\n",
            "meta-llama/Llama-2-13b-hf\n",
            "meta-llama/Llama-2-70b-chat-hf\n",
            "meta-llama/Llama-2-70b-hf\n",
            "meta-llama/Llama-2-7b-chat-hf\n",
            "meta-llama/Llama-2-7b-hf\n",
            "meta-llama/Llama-3.2-1B\n",
            "meta-llama/Llama-3.2-1B-Instruct\n",
            "meta-llama/Llama-3.2-3B\n",
            "meta-llama/Llama-3.2-3B-Instruct\n",
            "meta-llama/Meta-Llama-3-70B\n",
            "meta-llama/Meta-Llama-3-70B-Instruct\n",
            "meta-llama/Meta-Llama-3-8B\n",
            "meta-llama/Meta-Llama-3-8B-Instruct\n",
            "meta-llama/Meta-Llama-3.1-405B\n",
            "meta-llama/Meta-Llama-3.1-405B-Instruct\n",
            "meta-llama/Meta-Llama-3.1-70B\n",
            "meta-llama/Meta-Llama-3.1-70B-Instruct\n",
            "meta-llama/Meta-Llama-3.1-8B\n",
            "meta-llama/Meta-Llama-3.1-8B-Instruct\n",
            "microsoft/phi-1_5\n",
            "microsoft/phi-2\n",
            "microsoft/Phi-3-mini-128k-instruct\n",
            "microsoft/Phi-3-mini-4k-instruct\n",
            "microsoft/Phi-3.5-mini-instruct\n",
            "mistralai/mathstral-7B-v0.1\n",
            "mistralai/Mistral-7B-Instruct-v0.1\n",
            "mistralai/Mistral-7B-Instruct-v0.2\n",
            "mistralai/Mistral-7B-Instruct-v0.3\n",
            "mistralai/Mistral-7B-v0.1\n",
            "mistralai/Mistral-7B-v0.3\n",
            "mistralai/Mistral-Large-Instruct-2407\n",
            "mistralai/Mixtral-8x7B-Instruct-v0.1\n",
            "mistralai/Mixtral-8x7B-v0.1\n",
            "NousResearch/Nous-Hermes-13b\n",
            "NousResearch/Nous-Hermes-llama-2-7b\n",
            "NousResearch/Nous-Hermes-Llama2-13b\n",
            "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF\n",
            "openlm-research/open_llama_13b\n",
            "openlm-research/open_llama_3b\n",
            "openlm-research/open_llama_7b\n",
            "stabilityai/FreeWilly2\n",
            "stabilityai/stable-code-3b\n",
            "stabilityai/stablecode-completion-alpha-3b\n",
            "stabilityai/stablecode-completion-alpha-3b-4k\n",
            "stabilityai/stablecode-instruct-alpha-3b\n",
            "stabilityai/stablelm-3b-4e1t\n",
            "stabilityai/stablelm-base-alpha-3b\n",
            "stabilityai/stablelm-base-alpha-7b\n",
            "stabilityai/stablelm-tuned-alpha-3b\n",
            "stabilityai/stablelm-tuned-alpha-7b\n",
            "stabilityai/stablelm-zephyr-3b\n",
            "tiiuae/falcon-180B\n",
            "tiiuae/falcon-180B-chat\n",
            "tiiuae/falcon-40b\n",
            "tiiuae/falcon-40b-instruct\n",
            "tiiuae/falcon-7b\n",
            "tiiuae/falcon-7b-instruct\n",
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0\n",
            "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
            "togethercomputer/LLaMA-2-7B-32K\n",
            "togethercomputer/RedPajama-INCITE-7B-Base\n",
            "togethercomputer/RedPajama-INCITE-7B-Chat\n",
            "togethercomputer/RedPajama-INCITE-7B-Instruct\n",
            "togethercomputer/RedPajama-INCITE-Base-3B-v1\n",
            "togethercomputer/RedPajama-INCITE-Base-7B-v0.1\n",
            "togethercomputer/RedPajama-INCITE-Chat-3B-v1\n",
            "togethercomputer/RedPajama-INCITE-Chat-7B-v0.1\n",
            "togethercomputer/RedPajama-INCITE-Instruct-3B-v1\n",
            "togethercomputer/RedPajama-INCITE-Instruct-7B-v0.1\n",
            "Trelis/Llama-2-7b-chat-hf-function-calling-v2\n",
            "unsloth/Mistral-7B-v0.2\n"
          ]
        }
      ],
      "source": [
        "!litgpt download list"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mpQk2Hakz-aS"
      },
      "source": [
        "In this notebook, you will use Gemma 2's 2b model. Download the model weights using the following command:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "EuJFGgdgz7jk"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Setting HF_HUB_ENABLE_HF_TRANSFER=1\n",
            "config.json: 100% 818/818 [00:00<00:00, 5.36MB/s]\n",
            "generation_config.json: 100% 168/168 [00:00<00:00, 1.06MB/s]\n",
            "model-00001-of-00003.safetensors: 100% 4.99G/4.99G [00:12<00:00, 400MB/s]\n",
            "model-00002-of-00003.safetensors: 100% 4.98G/4.98G [00:14<00:00, 346MB/s]\n",
            "model-00003-of-00003.safetensors: 100% 481M/481M [00:01<00:00, 458MB/s]\n",
            "model.safetensors.index.json: 100% 24.2k/24.2k [00:00<00:00, 48.1MB/s]\n",
            "tokenizer.json: 100% 17.5M/17.5M [00:00<00:00, 42.6MB/s]\n",
            "tokenizer.model: 100% 4.24M/4.24M [00:00<00:00, 48.3MB/s]\n",
            "tokenizer_config.json: 100% 46.4k/46.4k [00:00<00:00, 48.9MB/s]\n",
            "Converting .safetensor files to PyTorch binaries (.bin)\n",
            "checkpoints/google/gemma-2-2b/model-00003-of-00003.safetensors --> checkpoints/google/gemma-2-2b/model-00003-of-00003.bin\n",
            "checkpoints/google/gemma-2-2b/model-00001-of-00003.safetensors --> checkpoints/google/gemma-2-2b/model-00001-of-00003.bin\n",
            "checkpoints/google/gemma-2-2b/model-00002-of-00003.safetensors --> checkpoints/google/gemma-2-2b/model-00002-of-00003.bin\n",
            "Converting checkpoint files to LitGPT format.\n",
            "{'checkpoint_dir': PosixPath('checkpoints/google/gemma-2-2b'),\n",
            " 'debug_mode': False,\n",
            " 'dtype': None,\n",
            " 'model_name': None}\n",
            "Loading weights: model-00003-of-00003.bin: 100% 100.0/100 [00:25<00:00,  3.95it/s]\n",
            "Saving converted checkpoint to checkpoints/google/gemma-2-2b\n"
          ]
        }
      ],
      "source": [
        "!litgpt download google/gemma-2-2b"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KzLCcRBB0Zrl"
      },
      "source": [
        "### Fine-tune Gemma 2 on Alpaca dataset\n",
        "You will now fine-tune Gemma 2 on a subset of the Alpaca dataset.\n",
        "\n",
        "\n",
        "**Alpaca dataset**\n",
        "\n",
        "LitGPT supports instruction-tuning models on many popular open-source datasets, like Alpaca, Dolly, FLAN, etc., using a simple command-line interface. No need to download or prepare datasets separately; LitGPT handles this automatically.\n",
        "\n",
        "The full [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html) dataset contains 52,000 instruction-response pairs, suitable for fine-tuning language models to follow instructions. However, for this task, you'll use a smaller subset of 2000 samples, [Alpaca2k](https://github.com/Lightning-AI/litgpt/blob/7449dad90740c4b0947a6ccb474b869ef969e110/tutorials/prepare_dataset.md#alpaca-2k).\n",
        "\n",
        "Credits:\n",
        "[mhenrichsen/alpaca_2k_test](https://huggingface.co/datasets/mhenrichsen/alpaca_2k_test) (This dataset provides the 2,000-sample Alpaca2k subset supported by LitGPT).\n",
        "\n",
        "**LoRA fine-tuning**\n",
        "\n",
        "LitGPT supports various fine-tuning methods, including full fine-tuning, LoRA, QLoRA, and adapter fine-tuning.\n",
        "\n",
        "While full fine-tuning trains all model weight parameters, it's memory-intensive. For this reason, you'll use the LoRA technique to fine-tune Gemma 2 in this notebook.\n",
        "\n",
        "LoRA (Low-Rank Adaptation) is a technique that freezes the original model's weights and introduces small, trainable parameter matrices for each layer. This significantly reduces the number of trainable parameters, leading to lower computational and memory requirements during fine-tuning.\n",
        "\n",
        "LoRA reduces the storage requirements of LLMs without increasing inference latency.\n",
        "\n",
        "\n",
        "You can read more about LoRA by visiting the [official LoRA Github repository](https://github.com/microsoft/LoRA).\n",
        "\n",
        "\n",
        "**Command-line arguments**\n",
        "\n",
        "Use the `litgpt finetune_lora` command to fine-tune Gemma 2.\n",
        "The following command line arguments are specified:\n",
        "1. `--data`: Specifies the dataset to be used for fine-tuning. While LitGPT supports various datasets, you'll be using `Alpaca2k` for this task. For more details about the supported datasets and data-specific command line arguments please refer to the [LitGPT preparing datasets tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/prepare_dataset.md).\n",
        "2. `--train.max_seq_length`: The maximum sequence length of the tokenized training samples. Samples that exceed this sequence length are truncated leading to a reduction in computational resource requirements. The maximum sequence length can be determined from the distribution of the training samples. In this tutorial, this value is set to 512. You can read more about this parameter in the [LitGPT preparing datasets tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/prepare_dataset.md#truncating-datasets). You can explore the distribution of the `Alpaca2k` dataset in the [Alpaca2k section](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/prepare_dataset.md#alpaca-2k) of this tutorial.\n",
        "3. `--train.micro_batch_size`: Determines the number of samples processed per iteration. This value is set to 2 in this tutorial to avoid out-of-memory errors on the T4 GPU. You can adjust this based on your GPU's memory capacity.\n",
        "4. `--train.epochs`: Specifies the number of epochs to fine-tune the model for. In this example, the model will be fine-tuned for one epoch. For better results, you can increase the number of epochs.\n",
        "5. `--out_dir`: Specifies the directory where checkpoints are periodically saved during fine-tuning.\n",
        "6. `--precision`: Sets the precision to `bf16-true`. Using lower precision (bf16) reduces memory usage compared to 32-bit precision. You can find more details on this in the [LitGPT handling out-of-memory errors guide](https://github.com/Lightning-AI/litgpt/blob/7449dad90740c4b0947a6ccb474b869ef969e110/tutorials/oom.md#use-lower-precision).\n",
        "\n",
        "In addition to these parameters, you can customize other train, evaluation, LoRA, and dataset-specific settings in LitGPT. Run `litgpt finetune_lora --help` to see all configurable parameters.\n",
        "\n",
        "Note: Fine-tuning Gemma 2 on Alpaca 2K with the specified hyperparameters takes about 45-50 minutes on a T4 GPU. For better results, you can adjust the training configuration, fine-tune for longer periods or use the full Alpaca dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "rbHcacd-0Y-r"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'access_token': None,\n",
            " 'checkpoint_dir': PosixPath('checkpoints/google/gemma-2-2b'),\n",
            " 'data': Alpaca2k(mask_prompt=False,\n",
            "                  val_split_fraction=0.05,\n",
            "                  prompt_style=<litgpt.prompts.Alpaca object at 0x7cc16adf00a0>,\n",
            "                  ignore_index=-100,\n",
            "                  seed=42,\n",
            "                  num_workers=4,\n",
            "                  download_dir=PosixPath('data/alpaca2k')),\n",
            " 'devices': 1,\n",
            " 'eval': EvalArgs(interval=100,\n",
            "                  max_new_tokens=100,\n",
            "                  max_iters=100,\n",
            "                  initial_validation=False,\n",
            "                  final_validation=True,\n",
            "                  evaluate_example='first'),\n",
            " 'logger_name': 'csv',\n",
            " 'lora_alpha': 16,\n",
            " 'lora_dropout': 0.05,\n",
            " 'lora_head': False,\n",
            " 'lora_key': False,\n",
            " 'lora_mlp': False,\n",
            " 'lora_projection': False,\n",
            " 'lora_query': True,\n",
            " 'lora_r': 8,\n",
            " 'lora_value': True,\n",
            " 'num_nodes': 1,\n",
            " 'optimizer': 'AdamW',\n",
            " 'out_dir': PosixPath('out/lit-finetuned/gemma-2-alpaca-it'),\n",
            " 'precision': 'bf16-true',\n",
            " 'quantize': None,\n",
            " 'seed': 1337,\n",
            " 'train': TrainArgs(save_interval=1000,\n",
            "                    log_interval=1,\n",
            "                    global_batch_size=16,\n",
            "                    micro_batch_size=2,\n",
            "                    lr_warmup_steps=100,\n",
            "                    lr_warmup_fraction=None,\n",
            "                    epochs=1,\n",
            "                    max_tokens=None,\n",
            "                    max_steps=None,\n",
            "                    max_seq_length=512,\n",
            "                    tie_embeddings=None,\n",
            "                    max_norm=None,\n",
            "                    min_lr=6e-05)}\n",
            "README.md: 100% 28.0/28.0 [00:00<00:00, 196kB/s]\n",
            "alpaca_2000.parquet: 100% 1.76M/1.76M [00:00<00:00, 34.7MB/s]\n",
            "Generating train split: 100% 2000/2000 [00:00<00:00, 51351.07 examples/s]\n",
            "Seed set to 1337\n",
            "Number of trainable parameters: 1,597,440\n",
            "Number of non-trainable parameters: 3,204,165,888\n",
            "The longest sequence length in the train data is 512, the model's maximum sequence length is 512 and context length is 8192\n",
            "Verifying settings ...\n",
            "/usr/local/lib/python3.10/dist-packages/lightning/fabric/utilities/throughput.py:584: 'Tesla T4' does not support torch.bfloat16\n",
            "Epoch 1 | iter 1 step 0 | loss train: 2.538, val: n/a | iter time: 3426.19 ms\n",
            "Epoch 1 | iter 2 step 0 | loss train: 3.047, val: n/a | iter time: 3032.69 ms\n",
            "Epoch 1 | iter 3 step 0 | loss train: 3.774, val: n/a | iter time: 1871.92 ms\n",
            "Epoch 1 | iter 4 step 0 | loss train: 3.515, val: n/a | iter time: 2676.23 ms\n",
            "Epoch 1 | iter 5 step 0 | loss train: 3.235, val: n/a | iter time: 4510.61 ms\n",
            "Epoch 1 | iter 6 step 0 | loss train: 3.477, val: n/a | iter time: 2486.04 ms\n",
            "Epoch 1 | iter 7 step 0 | loss train: 3.572, val: n/a | iter time: 1688.35 ms\n",
            "Epoch 1 | iter 8 step 1 | loss train: 3.572, val: n/a | iter time: 4025.33 ms (step)\n",
            "Epoch 1 | iter 9 step 1 | loss train: 3.958, val: n/a | iter time: 1037.77 ms\n",
            "Epoch 1 | iter 10 step 1 | loss train: 3.900, val: n/a | iter time: 2270.59 ms\n",
            "Epoch 1 | iter 11 step 1 | loss train: 3.688, val: n/a | iter time: 3120.65 ms\n",
            "Epoch 1 | iter 12 step 1 | loss train: 3.937, val: n/a | iter time: 2551.25 ms\n",
            "Epoch 1 | iter 13 step 1 | loss train: 4.113, val: n/a | iter time: 2037.06 ms\n",
            "Epoch 1 | iter 14 step 1 | loss train: 4.034, val: n/a | iter time: 2689.87 ms\n",
            "Epoch 1 | iter 15 step 1 | loss train: 4.140, val: n/a | iter time: 1750.49 ms\n",
            "Epoch 1 | iter 16 step 2 | loss train: 4.339, val: n/a | iter time: 1682.45 ms (step)\n",
            "Epoch 1 | iter 17 step 2 | loss train: 4.200, val: n/a | iter time: 1730.35 ms\n",
            "Epoch 1 | iter 18 step 2 | loss train: 4.550, val: n/a | iter time: 1169.61 ms\n",
            "Epoch 1 | iter 19 step 2 | loss train: 4.930, val: n/a | iter time: 875.50 ms\n",
            "Epoch 1 | iter 20 step 2 | loss train: 4.986, val: n/a | iter time: 890.26 ms\n",
            "Epoch 1 | iter 21 step 2 | loss train: 5.325, val: n/a | iter time: 1105.60 ms\n",
            "Epoch 1 | iter 22 step 2 | loss train: 5.268, val: n/a | iter time: 3885.93 ms\n",
            "Epoch 1 | iter 23 step 2 | loss train: 5.348, val: n/a | iter time: 1084.29 ms\n",
            "Epoch 1 | iter 24 step 3 | loss train: 5.585, val: n/a | iter time: 812.88 ms (step)\n",
            "Epoch 1 | iter 25 step 3 | loss train: 5.616, val: n/a | iter time: 1505.26 ms\n",
            "Epoch 1 | iter 26 step 3 | loss train: 5.346, val: n/a | iter time: 2906.83 ms\n",
            "Epoch 1 | iter 27 step 3 | loss train: 4.925, val: n/a | iter time: 2905.13 ms\n",
            "Epoch 1 | iter 28 step 3 | loss train: 5.016, val: n/a | iter time: 1111.55 ms\n",
            "Epoch 1 | iter 29 step 3 | loss train: 5.079, val: n/a | iter time: 898.65 ms\n",
            "Epoch 1 | iter 30 step 3 | loss train: 5.060, val: n/a | iter time: 3536.16 ms\n",
            "Epoch 1 | iter 31 step 3 | loss train: 4.925, val: n/a | iter time: 1738.34 ms\n",
            "Epoch 1 | iter 32 step 4 | loss train: 4.340, val: n/a | iter time: 3661.17 ms (step)\n",
            "Epoch 1 | iter 33 step 4 | loss train: 4.162, val: n/a | iter time: 3397.82 ms\n",
            "Epoch 1 | iter 34 step 4 | loss train: 4.089, val: n/a | iter time: 3029.18 ms\n",
            "Epoch 1 | iter 35 step 4 | loss train: 4.163, val: n/a | iter time: 2638.14 ms\n",
            "Epoch 1 | iter 36 step 4 | loss train: 3.980, val: n/a | iter time: 1873.76 ms\n",
            "Epoch 1 | iter 37 step 4 | loss train: 3.523, val: n/a | iter time: 4308.46 ms\n",
            "Epoch 1 | iter 38 step 4 | loss train: 3.423, val: n/a | iter time: 3815.22 ms\n",
            "Epoch 1 | iter 39 step 4 | loss train: 3.829, val: n/a | iter time: 926.98 ms\n",
            "Epoch 1 | iter 40 step 5 | loss train: 4.041, val: n/a | iter time: 2546.43 ms (step)\n",
            "Epoch 1 | iter 41 step 5 | loss train: 3.956, val: n/a | iter time: 5111.39 ms\n",
            "Epoch 1 | iter 42 step 5 | loss train: 4.085, val: n/a | iter time: 4445.24 ms\n",
            "Epoch 1 | iter 43 step 5 | loss train: 4.349, val: n/a | iter time: 1876.52 ms\n",
            "Epoch 1 | iter 44 step 5 | loss train: 4.161, val: n/a | iter time: 3277.21 ms\n",
            "Epoch 1 | iter 45 step 5 | loss train: 4.164, val: n/a | iter time: 2901.50 ms\n",
            "Epoch 1 | iter 46 step 5 | loss train: 4.242, val: n/a | iter time: 2964.04 ms\n",
            "Epoch 1 | iter 47 step 5 | loss train: 3.836, val: n/a | iter time: 2257.58 ms\n",
            "Epoch 1 | iter 48 step 6 | loss train: 4.033, val: n/a | iter time: 1580.90 ms (step)\n",
            "Epoch 1 | iter 49 step 6 | loss train: 4.351, val: n/a | iter time: 1691.60 ms\n",
            "Epoch 1 | iter 50 step 6 | loss train: 4.254, val: n/a | iter time: 4500.43 ms\n",
            "Epoch 1 | iter 51 step 6 | loss train: 4.151, val: n/a | iter time: 1712.96 ms\n",
            "Epoch 1 | iter 52 step 6 | loss train: 4.359, val: n/a | iter time: 2018.09 ms\n",
            "Epoch 1 | iter 53 step 6 | loss train: 4.394, val: n/a | iter time: 3811.68 ms\n",
            "Epoch 1 | iter 54 step 6 | loss train: 4.606, val: n/a | iter time: 1335.03 ms\n",
            "Epoch 1 | iter 55 step 6 | loss train: 4.310, val: n/a | iter time: 5596.09 ms\n",
            "Epoch 1 | iter 56 step 7 | loss train: 3.802, val: n/a | iter time: 5741.61 ms (step)\n",
            "Epoch 1 | iter 57 step 7 | loss train: 3.878, val: n/a | iter time: 978.70 ms\n",
            "Epoch 1 | iter 58 step 7 | loss train: 3.914, val: n/a | iter time: 3662.58 ms\n",
            "Epoch 1 | iter 59 step 7 | loss train: 3.707, val: n/a | iter time: 4320.84 ms\n",
            "Epoch 1 | iter 60 step 7 | loss train: 3.652, val: n/a | iter time: 2557.47 ms\n",
            "Epoch 1 | iter 61 step 7 | loss train: 3.609, val: n/a | iter time: 3524.27 ms\n",
            "Epoch 1 | iter 62 step 7 | loss train: 3.681, val: n/a | iter time: 1927.83 ms\n",
            "Epoch 1 | iter 63 step 7 | loss train: 3.813, val: n/a | iter time: 2958.07 ms\n",
            "Epoch 1 | iter 64 step 8 | loss train: 3.957, val: n/a | iter time: 2188.81 ms (step)\n",
            "Epoch 1 | iter 65 step 8 | loss train: 3.510, val: n/a | iter time: 5177.00 ms\n",
            "Epoch 1 | iter 66 step 8 | loss train: 3.372, val: n/a | iter time: 4558.88 ms\n",
            "Epoch 1 | iter 67 step 8 | loss train: 3.543, val: n/a | iter time: 3136.68 ms\n",
            "Epoch 1 | iter 68 step 8 | loss train: 3.484, val: n/a | iter time: 2574.98 ms\n",
            "Epoch 1 | iter 69 step 8 | loss train: 3.752, val: n/a | iter time: 1288.06 ms\n",
            "Epoch 1 | iter 70 step 8 | loss train: 3.421, val: n/a | iter time: 3608.94 ms\n",
            "Epoch 1 | iter 71 step 8 | loss train: 3.517, val: n/a | iter time: 1661.51 ms\n",
            "Epoch 1 | iter 72 step 9 | loss train: 3.577, val: n/a | iter time: 2938.26 ms (step)\n",
            "Epoch 1 | iter 73 step 9 | loss train: 3.880, val: n/a | iter time: 2505.62 ms\n",
            "Epoch 1 | iter 74 step 9 | loss train: 3.953, val: n/a | iter time: 2612.92 ms\n",
            "Epoch 1 | iter 75 step 9 | loss train: 3.778, val: n/a | iter time: 1692.82 ms\n",
            "Epoch 1 | iter 76 step 9 | loss train: 3.770, val: n/a | iter time: 1892.93 ms\n",
            "Epoch 1 | iter 77 step 9 | loss train: 3.610, val: n/a | iter time: 2923.56 ms\n",
            "Epoch 1 | iter 78 step 9 | loss train: 3.561, val: n/a | iter time: 4020.31 ms\n",
            "Epoch 1 | iter 79 step 9 | loss train: 3.866, val: n/a | iter time: 878.38 ms\n",
            "Epoch 1 | iter 80 step 10 | loss train: 3.800, val: n/a | iter time: 2637.19 ms (step)\n",
            "Epoch 1 | iter 81 step 10 | loss train: 3.632, val: n/a | iter time: 2337.28 ms\n",
            "Epoch 1 | iter 82 step 10 | loss train: 3.663, val: n/a | iter time: 2353.70 ms\n",
            "Epoch 1 | iter 83 step 10 | loss train: 3.745, val: n/a | iter time: 1880.31 ms\n",
            "Epoch 1 | iter 84 step 10 | loss train: 3.593, val: n/a | iter time: 4138.46 ms\n",
            "Epoch 1 | iter 85 step 10 | loss train: 3.515, val: n/a | iter time: 5498.51 ms\n",
            "Epoch 1 | iter 86 step 10 | loss train: 3.506, val: n/a | iter time: 3319.40 ms\n",
            "Epoch 1 | iter 87 step 10 | loss train: 3.056, val: n/a | iter time: 3295.30 ms\n",
            "Epoch 1 | iter 88 step 11 | loss train: 3.170, val: n/a | iter time: 2936.03 ms (step)\n",
            "Epoch 1 | iter 89 step 11 | loss train: 3.307, val: n/a | iter time: 1248.69 ms\n",
            "Epoch 1 | iter 90 step 11 | loss train: 3.321, val: n/a | iter time: 922.59 ms\n",
            "Epoch 1 | iter 91 step 11 | loss train: 3.412, val: n/a | iter time: 1244.60 ms\n",
            "Epoch 1 | iter 92 step 11 | loss train: 3.470, val: n/a | iter time: 1928.45 ms\n",
            "Epoch 1 | iter 93 step 11 | loss train: 3.512, val: n/a | iter time: 4321.48 ms\n",
            "Epoch 1 | iter 94 step 11 | loss train: 3.568, val: n/a | iter time: 3584.04 ms\n",
            "Epoch 1 | iter 95 step 11 | loss train: 3.716, val: n/a | iter time: 2618.61 ms\n",
            "Epoch 1 | iter 96 step 12 | loss train: 3.842, val: n/a | iter time: 1577.83 ms (step)\n",
            "Epoch 1 | iter 97 step 12 | loss train: 3.506, val: n/a | iter time: 2614.41 ms\n",
            "Epoch 1 | iter 98 step 12 | loss train: 3.360, val: n/a | iter time: 2920.39 ms\n",
            "Epoch 1 | iter 99 step 12 | loss train: 3.128, val: n/a | iter time: 1964.58 ms\n",
            "Epoch 1 | iter 100 step 12 | loss train: 3.312, val: n/a | iter time: 1234.52 ms\n",
            "Epoch 1 | iter 101 step 12 | loss train: 3.237, val: n/a | iter time: 3875.33 ms\n",
            "Epoch 1 | iter 102 step 12 | loss train: 3.266, val: n/a | iter time: 2357.28 ms\n",
            "Epoch 1 | iter 103 step 12 | loss train: 3.305, val: n/a | iter time: 1605.20 ms\n",
            "Epoch 1 | iter 104 step 13 | loss train: 3.063, val: n/a | iter time: 2915.50 ms (step)\n",
            "Epoch 1 | iter 105 step 13 | loss train: 3.360, val: n/a | iter time: 1230.03 ms\n",
            "Epoch 1 | iter 106 step 13 | loss train: 3.565, val: n/a | iter time: 1340.56 ms\n",
            "Epoch 1 | iter 107 step 13 | loss train: 3.507, val: n/a | iter time: 3881.12 ms\n",
            "Epoch 1 | iter 108 step 13 | loss train: 3.510, val: n/a | iter time: 1220.96 ms\n",
            "Epoch 1 | iter 109 step 13 | loss train: 3.586, val: n/a | iter time: 1658.92 ms\n",
            "Epoch 1 | iter 110 step 13 | loss train: 3.502, val: n/a | iter time: 2910.13 ms\n",
            "Epoch 1 | iter 111 step 13 | loss train: 3.329, val: n/a | iter time: 3017.36 ms\n",
            "Epoch 1 | iter 112 step 14 | loss train: 3.335, val: n/a | iter time: 3568.80 ms (step)\n",
            "Epoch 1 | iter 113 step 14 | loss train: 3.285, val: n/a | iter time: 982.24 ms\n",
            "Epoch 1 | iter 114 step 14 | loss train: 3.203, val: n/a | iter time: 1596.91 ms\n",
            "Epoch 1 | iter 115 step 14 | loss train: 3.089, val: n/a | iter time: 5467.28 ms\n",
            "Epoch 1 | iter 116 step 14 | loss train: 2.834, val: n/a | iter time: 4027.77 ms\n",
            "Epoch 1 | iter 117 step 14 | loss train: 2.809, val: n/a | iter time: 1963.16 ms\n",
            "Epoch 1 | iter 118 step 14 | loss train: 2.855, val: n/a | iter time: 1979.58 ms\n",
            "Epoch 1 | iter 119 step 14 | loss train: 2.921, val: n/a | iter time: 1903.05 ms\n",
            "Epoch 1 | iter 120 step 15 | loss train: 2.913, val: n/a | iter time: 4033.59 ms (step)\n",
            "Epoch 1 | iter 121 step 15 | loss train: 2.919, val: n/a | iter time: 904.78 ms\n",
            "Epoch 1 | iter 122 step 15 | loss train: 2.825, val: n/a | iter time: 3672.71 ms\n",
            "Epoch 1 | iter 123 step 15 | loss train: 2.915, val: n/a | iter time: 5071.62 ms\n",
            "Epoch 1 | iter 124 step 15 | loss train: 3.024, val: n/a | iter time: 1963.67 ms\n",
            "Epoch 1 | iter 125 step 15 | loss train: 2.930, val: n/a | iter time: 2227.00 ms\n",
            "Epoch 1 | iter 126 step 15 | loss train: 2.912, val: n/a | iter time: 2336.66 ms\n",
            "Epoch 1 | iter 127 step 15 | loss train: 2.729, val: n/a | iter time: 3223.66 ms\n",
            "Epoch 1 | iter 128 step 16 | loss train: 2.925, val: n/a | iter time: 5503.68 ms (step)\n",
            "Epoch 1 | iter 129 step 16 | loss train: 2.701, val: n/a | iter time: 2917.52 ms\n",
            "Epoch 1 | iter 130 step 16 | loss train: 2.721, val: n/a | iter time: 3661.05 ms\n",
            "Epoch 1 | iter 131 step 16 | loss train: 2.680, val: n/a | iter time: 2845.08 ms\n",
            "Epoch 1 | iter 132 step 16 | loss train: 2.577, val: n/a | iter time: 3264.14 ms\n",
            "Epoch 1 | iter 133 step 16 | loss train: 2.569, val: n/a | iter time: 2266.26 ms\n",
            "Epoch 1 | iter 134 step 16 | loss train: 2.652, val: n/a | iter time: 2602.32 ms\n",
            "Epoch 1 | iter 135 step 16 | loss train: 2.748, val: n/a | iter time: 1936.16 ms\n",
            "Epoch 1 | iter 136 step 17 | loss train: 2.527, val: n/a | iter time: 1973.42 ms (step)\n",
            "Epoch 1 | iter 137 step 17 | loss train: 2.620, val: n/a | iter time: 976.58 ms\n",
            "Epoch 1 | iter 138 step 17 | loss train: 2.455, val: n/a | iter time: 5394.75 ms\n",
            "Epoch 1 | iter 139 step 17 | loss train: 2.378, val: n/a | iter time: 2215.31 ms\n",
            "Epoch 1 | iter 140 step 17 | loss train: 2.440, val: n/a | iter time: 2798.07 ms\n",
            "Epoch 1 | iter 141 step 17 | loss train: 2.342, val: n/a | iter time: 2635.62 ms\n",
            "Epoch 1 | iter 142 step 17 | loss train: 2.232, val: n/a | iter time: 1230.33 ms\n",
            "Epoch 1 | iter 143 step 17 | loss train: 2.187, val: n/a | iter time: 1603.13 ms\n",
            "Epoch 1 | iter 144 step 18 | loss train: 2.208, val: n/a | iter time: 975.69 ms (step)\n",
            "Epoch 1 | iter 145 step 18 | loss train: 2.217, val: n/a | iter time: 3923.13 ms\n",
            "Epoch 1 | iter 146 step 18 | loss train: 2.360, val: n/a | iter time: 2917.51 ms\n",
            "Epoch 1 | iter 147 step 18 | loss train: 2.573, val: n/a | iter time: 987.01 ms\n",
            "Epoch 1 | iter 148 step 18 | loss train: 2.520, val: n/a | iter time: 1997.38 ms\n",
            "Epoch 1 | iter 149 step 18 | loss train: 2.515, val: n/a | iter time: 3206.82 ms\n",
            "Epoch 1 | iter 150 step 18 | loss train: 2.448, val: n/a | iter time: 3368.48 ms\n",
            "Epoch 1 | iter 151 step 18 | loss train: 2.523, val: n/a | iter time: 1901.58 ms\n",
            "Epoch 1 | iter 152 step 19 | loss train: 2.544, val: n/a | iter time: 877.22 ms (step)\n",
            "Epoch 1 | iter 153 step 19 | loss train: 2.534, val: n/a | iter time: 1329.60 ms\n",
            "Epoch 1 | iter 154 step 19 | loss train: 2.462, val: n/a | iter time: 3244.32 ms\n",
            "Epoch 1 | iter 155 step 19 | loss train: 2.316, val: n/a | iter time: 2545.64 ms\n",
            "Epoch 1 | iter 156 step 19 | loss train: 2.296, val: n/a | iter time: 4414.54 ms\n",
            "Epoch 1 | iter 157 step 19 | loss train: 2.298, val: n/a | iter time: 2370.03 ms\n",
            "Epoch 1 | iter 158 step 19 | loss train: 2.261, val: n/a | iter time: 4015.23 ms\n",
            "Epoch 1 | iter 159 step 19 | loss train: 2.197, val: n/a | iter time: 1573.87 ms\n",
            "Epoch 1 | iter 160 step 20 | loss train: 2.137, val: n/a | iter time: 1240.15 ms (step)\n",
            "Epoch 1 | iter 161 step 20 | loss train: 2.054, val: n/a | iter time: 4564.39 ms\n",
            "Epoch 1 | iter 162 step 20 | loss train: 2.044, val: n/a | iter time: 5423.34 ms\n",
            "Epoch 1 | iter 163 step 20 | loss train: 2.013, val: n/a | iter time: 5585.34 ms\n",
            "Epoch 1 | iter 164 step 20 | loss train: 1.959, val: n/a | iter time: 2320.56 ms\n",
            "Epoch 1 | iter 165 step 20 | loss train: 2.162, val: n/a | iter time: 685.35 ms\n",
            "Epoch 1 | iter 166 step 20 | loss train: 2.250, val: n/a | iter time: 1219.11 ms\n",
            "Epoch 1 | iter 167 step 20 | loss train: 2.179, val: n/a | iter time: 1558.46 ms\n",
            "Epoch 1 | iter 168 step 21 | loss train: 2.185, val: n/a | iter time: 3768.62 ms (step)\n",
            "Epoch 1 | iter 169 step 21 | loss train: 2.102, val: n/a | iter time: 4052.35 ms\n",
            "Epoch 1 | iter 170 step 21 | loss train: 2.056, val: n/a | iter time: 3578.11 ms\n",
            "Epoch 1 | iter 171 step 21 | loss train: 2.184, val: n/a | iter time: 983.94 ms\n",
            "Epoch 1 | iter 172 step 21 | loss train: 2.195, val: n/a | iter time: 1317.48 ms\n",
            "Epoch 1 | iter 173 step 21 | loss train: 1.948, val: n/a | iter time: 3948.45 ms\n",
            "Epoch 1 | iter 174 step 21 | loss train: 1.873, val: n/a | iter time: 2620.51 ms\n",
            "Epoch 1 | iter 175 step 21 | loss train: 1.878, val: n/a | iter time: 2623.16 ms\n",
            "Epoch 1 | iter 176 step 22 | loss train: 1.780, val: n/a | iter time: 4856.63 ms (step)\n",
            "Epoch 1 | iter 177 step 22 | loss train: 1.761, val: n/a | iter time: 3692.08 ms\n",
            "Epoch 1 | iter 178 step 22 | loss train: 1.736, val: n/a | iter time: 3663.83 ms\n",
            "Epoch 1 | iter 179 step 22 | loss train: 1.673, val: n/a | iter time: 899.16 ms\n",
            "Epoch 1 | iter 180 step 22 | loss train: 1.758, val: n/a | iter time: 2284.23 ms\n",
            "Epoch 1 | iter 181 step 22 | loss train: 1.815, val: n/a | iter time: 5338.17 ms\n",
            "Epoch 1 | iter 182 step 22 | loss train: 1.870, val: n/a | iter time: 1226.17 ms\n",
            "Epoch 1 | iter 183 step 22 | loss train: 1.807, val: n/a | iter time: 3213.69 ms\n",
            "Epoch 1 | iter 184 step 23 | loss train: 1.826, val: n/a | iter time: 1884.17 ms (step)\n",
            "Epoch 1 | iter 185 step 23 | loss train: 1.923, val: n/a | iter time: 1308.92 ms\n",
            "Epoch 1 | iter 186 step 23 | loss train: 1.951, val: n/a | iter time: 2983.10 ms\n",
            "Epoch 1 | iter 187 step 23 | loss train: 1.932, val: n/a | iter time: 1585.78 ms\n",
            "Epoch 1 | iter 188 step 23 | loss train: 1.760, val: n/a | iter time: 1868.64 ms\n",
            "Epoch 1 | iter 189 step 23 | loss train: 1.723, val: n/a | iter time: 2622.93 ms\n",
            "Epoch 1 | iter 190 step 23 | loss train: 1.654, val: n/a | iter time: 3540.66 ms\n",
            "Epoch 1 | iter 191 step 23 | loss train: 1.664, val: n/a | iter time: 4012.62 ms\n",
            "Epoch 1 | iter 192 step 24 | loss train: 1.615, val: n/a | iter time: 4372.57 ms (step)\n",
            "Epoch 1 | iter 193 step 24 | loss train: 1.555, val: n/a | iter time: 2912.62 ms\n",
            "Epoch 1 | iter 194 step 24 | loss train: 1.611, val: n/a | iter time: 1594.72 ms\n",
            "Epoch 1 | iter 195 step 24 | loss train: 1.478, val: n/a | iter time: 2660.06 ms\n",
            "Epoch 1 | iter 196 step 24 | loss train: 1.499, val: n/a | iter time: 3553.53 ms\n",
            "Epoch 1 | iter 197 step 24 | loss train: 1.401, val: n/a | iter time: 5343.96 ms\n",
            "Epoch 1 | iter 198 step 24 | loss train: 1.426, val: n/a | iter time: 1949.38 ms\n",
            "Epoch 1 | iter 199 step 24 | loss train: 1.366, val: n/a | iter time: 3641.24 ms\n",
            "Epoch 1 | iter 200 step 25 | loss train: 1.319, val: n/a | iter time: 3587.26 ms (step)\n",
            "Epoch 1 | iter 201 step 25 | loss train: 1.359, val: n/a | iter time: 1886.25 ms\n",
            "Epoch 1 | iter 202 step 25 | loss train: 1.348, val: n/a | iter time: 1193.07 ms\n",
            "Epoch 1 | iter 203 step 25 | loss train: 1.491, val: n/a | iter time: 1868.53 ms\n",
            "Epoch 1 | iter 204 step 25 | loss train: 1.576, val: n/a | iter time: 1675.91 ms\n",
            "Epoch 1 | iter 205 step 25 | loss train: 1.650, val: n/a | iter time: 5223.68 ms\n",
            "Epoch 1 | iter 206 step 25 | loss train: 1.604, val: n/a | iter time: 3613.38 ms\n",
            "Epoch 1 | iter 207 step 25 | loss train: 1.616, val: n/a | iter time: 2994.69 ms\n",
            "Epoch 1 | iter 208 step 26 | loss train: 1.665, val: n/a | iter time: 1335.89 ms (step)\n",
            "Epoch 1 | iter 209 step 26 | loss train: 1.657, val: n/a | iter time: 1656.02 ms\n",
            "Epoch 1 | iter 210 step 26 | loss train: 1.623, val: n/a | iter time: 3230.78 ms\n",
            "Epoch 1 | iter 211 step 26 | loss train: 1.530, val: n/a | iter time: 3881.58 ms\n",
            "Epoch 1 | iter 212 step 26 | loss train: 1.426, val: n/a | iter time: 1954.42 ms\n",
            "Epoch 1 | iter 213 step 26 | loss train: 1.459, val: n/a | iter time: 1202.94 ms\n",
            "Epoch 1 | iter 214 step 26 | loss train: 1.475, val: n/a | iter time: 1993.53 ms\n",
            "Epoch 1 | iter 215 step 26 | loss train: 1.413, val: n/a | iter time: 5391.94 ms\n",
            "Epoch 1 | iter 216 step 27 | loss train: 1.472, val: n/a | iter time: 3170.49 ms (step)\n",
            "Epoch 1 | iter 217 step 27 | loss train: 1.463, val: n/a | iter time: 1594.49 ms\n",
            "Epoch 1 | iter 218 step 27 | loss train: 1.479, val: n/a | iter time: 891.22 ms\n",
            "Epoch 1 | iter 219 step 27 | loss train: 1.512, val: n/a | iter time: 2585.38 ms\n",
            "Epoch 1 | iter 220 step 27 | loss train: 1.574, val: n/a | iter time: 1581.72 ms\n",
            "Epoch 1 | iter 221 step 27 | loss train: 1.549, val: n/a | iter time: 4267.64 ms\n",
            "Epoch 1 | iter 222 step 27 | loss train: 1.512, val: n/a | iter time: 2993.92 ms\n",
            "Epoch 1 | iter 223 step 27 | loss train: 1.604, val: n/a | iter time: 3291.66 ms\n",
            "Epoch 1 | iter 224 step 28 | loss train: 1.515, val: n/a | iter time: 4612.66 ms (step)\n",
            "Epoch 1 | iter 225 step 28 | loss train: 1.448, val: n/a | iter time: 3202.84 ms\n",
            "Epoch 1 | iter 226 step 28 | loss train: 1.351, val: n/a | iter time: 5368.03 ms\n",
            "Epoch 1 | iter 227 step 28 | loss train: 1.353, val: n/a | iter time: 2619.43 ms\n",
            "Epoch 1 | iter 228 step 28 | loss train: 1.316, val: n/a | iter time: 887.82 ms\n",
            "Epoch 1 | iter 229 step 28 | loss train: 1.274, val: n/a | iter time: 4241.16 ms\n",
            "Epoch 1 | iter 230 step 28 | loss train: 1.244, val: n/a | iter time: 5029.94 ms\n",
            "Epoch 1 | iter 231 step 28 | loss train: 1.238, val: n/a | iter time: 2293.78 ms\n",
            "Epoch 1 | iter 232 step 29 | loss train: 1.300, val: n/a | iter time: 1316.51 ms (step)\n",
            "Epoch 1 | iter 233 step 29 | loss train: 1.351, val: n/a | iter time: 3941.67 ms\n",
            "Epoch 1 | iter 234 step 29 | loss train: 1.392, val: n/a | iter time: 2607.83 ms\n",
            "Epoch 1 | iter 235 step 29 | loss train: 1.353, val: n/a | iter time: 1861.41 ms\n",
            "Epoch 1 | iter 236 step 29 | loss train: 1.303, val: n/a | iter time: 4022.57 ms\n",
            "Epoch 1 | iter 237 step 29 | loss train: 1.334, val: n/a | iter time: 4962.81 ms\n",
            "Epoch 1 | iter 238 step 29 | loss train: 1.335, val: n/a | iter time: 4268.78 ms\n",
            "Epoch 1 | iter 239 step 29 | loss train: 1.374, val: n/a | iter time: 3167.86 ms\n",
            "Epoch 1 | iter 240 step 30 | loss train: 1.423, val: n/a | iter time: 1218.24 ms (step)\n",
            "Epoch 1 | iter 241 step 30 | loss train: 1.376, val: n/a | iter time: 1585.17 ms\n",
            "Epoch 1 | iter 242 step 30 | loss train: 1.394, val: n/a | iter time: 4341.90 ms\n",
            "Epoch 1 | iter 243 step 30 | loss train: 1.418, val: n/a | iter time: 1654.72 ms\n",
            "Epoch 1 | iter 244 step 30 | loss train: 1.471, val: n/a | iter time: 981.14 ms\n",
            "Epoch 1 | iter 245 step 30 | loss train: 1.449, val: n/a | iter time: 3259.17 ms\n",
            "Epoch 1 | iter 246 step 30 | loss train: 1.491, val: n/a | iter time: 1320.21 ms\n",
            "Epoch 1 | iter 247 step 30 | loss train: 1.428, val: n/a | iter time: 3944.30 ms\n",
            "Epoch 1 | iter 248 step 31 | loss train: 1.324, val: n/a | iter time: 860.11 ms (step)\n",
            "Epoch 1 | iter 249 step 31 | loss train: 1.330, val: n/a | iter time: 5379.95 ms\n",
            "Epoch 1 | iter 250 step 31 | loss train: 1.265, val: n/a | iter time: 1986.93 ms\n",
            "Epoch 1 | iter 251 step 31 | loss train: 1.224, val: n/a | iter time: 887.01 ms\n",
            "Epoch 1 | iter 252 step 31 | loss train: 1.211, val: n/a | iter time: 3186.76 ms\n",
            "Epoch 1 | iter 253 step 31 | loss train: 1.233, val: n/a | iter time: 1898.22 ms\n",
            "Epoch 1 | iter 254 step 31 | loss train: 1.171, val: n/a | iter time: 4363.96 ms\n",
            "Epoch 1 | iter 255 step 31 | loss train: 1.165, val: n/a | iter time: 3620.45 ms\n",
            "Epoch 1 | iter 256 step 32 | loss train: 1.163, val: n/a | iter time: 4783.73 ms (step)\n",
            "Epoch 1 | iter 257 step 32 | loss train: 1.145, val: n/a | iter time: 3193.30 ms\n",
            "Epoch 1 | iter 258 step 32 | loss train: 1.181, val: n/a | iter time: 1958.15 ms\n",
            "Epoch 1 | iter 259 step 32 | loss train: 1.153, val: n/a | iter time: 5050.93 ms\n",
            "Epoch 1 | iter 260 step 32 | loss train: 1.165, val: n/a | iter time: 4534.73 ms\n",
            "Epoch 1 | iter 261 step 32 | loss train: 1.100, val: n/a | iter time: 3640.39 ms\n",
            "Epoch 1 | iter 262 step 32 | loss train: 1.123, val: n/a | iter time: 1882.55 ms\n",
            "Epoch 1 | iter 263 step 32 | loss train: 1.139, val: n/a | iter time: 1667.97 ms\n",
            "Epoch 1 | iter 264 step 33 | loss train: 1.196, val: n/a | iter time: 1222.80 ms (step)\n",
            "Epoch 1 | iter 265 step 33 | loss train: 1.200, val: n/a | iter time: 4409.41 ms\n",
            "Epoch 1 | iter 266 step 33 | loss train: 1.201, val: n/a | iter time: 4254.87 ms\n",
            "Epoch 1 | iter 267 step 33 | loss train: 1.235, val: n/a | iter time: 2229.12 ms\n",
            "Epoch 1 | iter 268 step 33 | loss train: 1.224, val: n/a | iter time: 3296.93 ms\n",
            "Epoch 1 | iter 269 step 33 | loss train: 1.262, val: n/a | iter time: 1312.80 ms\n",
            "Epoch 1 | iter 270 step 33 | loss train: 1.294, val: n/a | iter time: 3600.04 ms\n",
            "Epoch 1 | iter 271 step 33 | loss train: 1.288, val: n/a | iter time: 4940.02 ms\n",
            "Epoch 1 | iter 272 step 34 | loss train: 1.238, val: n/a | iter time: 3547.03 ms (step)\n",
            "Epoch 1 | iter 273 step 34 | loss train: 1.243, val: n/a | iter time: 4402.61 ms\n",
            "Epoch 1 | iter 274 step 34 | loss train: 1.168, val: n/a | iter time: 3259.84 ms\n",
            "Epoch 1 | iter 275 step 34 | loss train: 1.147, val: n/a | iter time: 5309.79 ms\n",
            "Epoch 1 | iter 276 step 34 | loss train: 1.098, val: n/a | iter time: 3275.57 ms\n",
            "Epoch 1 | iter 277 step 34 | loss train: 1.119, val: n/a | iter time: 1576.71 ms\n",
            "Epoch 1 | iter 278 step 34 | loss train: 1.103, val: n/a | iter time: 984.35 ms\n",
            "Epoch 1 | iter 279 step 34 | loss train: 1.167, val: n/a | iter time: 893.40 ms\n",
            "Epoch 1 | iter 280 step 35 | loss train: 1.157, val: n/a | iter time: 4948.75 ms (step)\n",
            "Epoch 1 | iter 281 step 35 | loss train: 1.162, val: n/a | iter time: 2228.42 ms\n",
            "Epoch 1 | iter 282 step 35 | loss train: 1.178, val: n/a | iter time: 5402.99 ms\n",
            "Epoch 1 | iter 283 step 35 | loss train: 1.200, val: n/a | iter time: 4983.40 ms\n",
            "Epoch 1 | iter 284 step 35 | loss train: 1.195, val: n/a | iter time: 871.54 ms\n",
            "Epoch 1 | iter 285 step 35 | loss train: 1.187, val: n/a | iter time: 4242.29 ms\n",
            "Epoch 1 | iter 286 step 35 | loss train: 1.175, val: n/a | iter time: 3620.68 ms\n",
            "Epoch 1 | iter 287 step 35 | loss train: 1.145, val: n/a | iter time: 3551.61 ms\n",
            "Epoch 1 | iter 288 step 36 | loss train: 1.167, val: n/a | iter time: 3913.16 ms (step)\n",
            "Epoch 1 | iter 289 step 36 | loss train: 1.206, val: n/a | iter time: 1678.35 ms\n",
            "Epoch 1 | iter 290 step 36 | loss train: 1.270, val: n/a | iter time: 1957.13 ms\n",
            "Epoch 1 | iter 291 step 36 | loss train: 1.227, val: n/a | iter time: 2986.93 ms\n",
            "Epoch 1 | iter 292 step 36 | loss train: 1.239, val: n/a | iter time: 2217.90 ms\n",
            "Epoch 1 | iter 293 step 36 | loss train: 1.287, val: n/a | iter time: 3613.77 ms\n",
            "Epoch 1 | iter 294 step 36 | loss train: 1.287, val: n/a | iter time: 1343.80 ms\n",
            "Epoch 1 | iter 295 step 36 | loss train: 1.238, val: n/a | iter time: 3758.49 ms\n",
            "Epoch 1 | iter 296 step 37 | loss train: 1.211, val: n/a | iter time: 1314.16 ms (step)\n",
            "Epoch 1 | iter 297 step 37 | loss train: 1.181, val: n/a | iter time: 2637.10 ms\n",
            "Epoch 1 | iter 298 step 37 | loss train: 1.119, val: n/a | iter time: 3654.45 ms\n",
            "Epoch 1 | iter 299 step 37 | loss train: 1.129, val: n/a | iter time: 2534.13 ms\n",
            "Epoch 1 | iter 300 step 37 | loss train: 1.143, val: n/a | iter time: 3616.18 ms\n",
            "Epoch 1 | iter 301 step 37 | loss train: 1.089, val: n/a | iter time: 1338.77 ms\n",
            "Epoch 1 | iter 302 step 37 | loss train: 1.106, val: n/a | iter time: 1609.71 ms\n",
            "Epoch 1 | iter 303 step 37 | loss train: 1.124, val: n/a | iter time: 3170.09 ms\n",
            "Epoch 1 | iter 304 step 38 | loss train: 1.175, val: n/a | iter time: 2930.95 ms (step)\n",
            "Epoch 1 | iter 305 step 38 | loss train: 1.146, val: n/a | iter time: 2960.99 ms\n",
            "Epoch 1 | iter 306 step 38 | loss train: 1.185, val: n/a | iter time: 2536.80 ms\n",
            "Epoch 1 | iter 307 step 38 | loss train: 1.236, val: n/a | iter time: 4011.01 ms\n",
            "Epoch 1 | iter 308 step 38 | loss train: 1.220, val: n/a | iter time: 875.02 ms\n",
            "Epoch 1 | iter 309 step 38 | loss train: 1.188, val: n/a | iter time: 4283.27 ms\n",
            "Epoch 1 | iter 310 step 38 | loss train: 1.136, val: n/a | iter time: 1612.42 ms\n",
            "Epoch 1 | iter 311 step 38 | loss train: 1.171, val: n/a | iter time: 2592.53 ms\n",
            "Epoch 1 | iter 312 step 39 | loss train: 1.116, val: n/a | iter time: 5047.49 ms (step)\n",
            "Epoch 1 | iter 313 step 39 | loss train: 1.091, val: n/a | iter time: 897.45 ms\n",
            "Epoch 1 | iter 314 step 39 | loss train: 1.108, val: n/a | iter time: 2614.83 ms\n",
            "Epoch 1 | iter 315 step 39 | loss train: 1.097, val: n/a | iter time: 3610.76 ms\n",
            "Epoch 1 | iter 316 step 39 | loss train: 1.090, val: n/a | iter time: 1319.99 ms\n",
            "Epoch 1 | iter 317 step 39 | loss train: 1.106, val: n/a | iter time: 2651.48 ms\n",
            "Epoch 1 | iter 318 step 39 | loss train: 1.161, val: n/a | iter time: 3039.39 ms\n",
            "Epoch 1 | iter 319 step 39 | loss train: 1.098, val: n/a | iter time: 882.55 ms\n",
            "Epoch 1 | iter 320 step 40 | loss train: 1.133, val: n/a | iter time: 988.39 ms (step)\n",
            "Epoch 1 | iter 321 step 40 | loss train: 1.152, val: n/a | iter time: 899.70 ms\n",
            "Epoch 1 | iter 322 step 40 | loss train: 1.128, val: n/a | iter time: 3072.25 ms\n",
            "Epoch 1 | iter 323 step 40 | loss train: 1.097, val: n/a | iter time: 1230.01 ms\n",
            "Epoch 1 | iter 324 step 40 | loss train: 1.082, val: n/a | iter time: 1008.05 ms\n",
            "Epoch 1 | iter 325 step 40 | loss train: 1.080, val: n/a | iter time: 3302.87 ms\n",
            "Epoch 1 | iter 326 step 40 | loss train: 1.062, val: n/a | iter time: 3713.65 ms\n",
            "Epoch 1 | iter 327 step 40 | loss train: 1.074, val: n/a | iter time: 3220.23 ms\n",
            "Epoch 1 | iter 328 step 41 | loss train: 1.077, val: n/a | iter time: 5427.04 ms (step)\n",
            "Epoch 1 | iter 329 step 41 | loss train: 1.087, val: n/a | iter time: 2327.57 ms\n",
            "Epoch 1 | iter 330 step 41 | loss train: 1.100, val: n/a | iter time: 2228.64 ms\n",
            "Epoch 1 | iter 331 step 41 | loss train: 1.132, val: n/a | iter time: 995.77 ms\n",
            "Epoch 1 | iter 332 step 41 | loss train: 1.178, val: n/a | iter time: 1899.36 ms\n",
            "Epoch 1 | iter 333 step 41 | loss train: 1.169, val: n/a | iter time: 998.04 ms\n",
            "Epoch 1 | iter 334 step 41 | loss train: 1.166, val: n/a | iter time: 3002.01 ms\n",
            "Epoch 1 | iter 335 step 41 | loss train: 1.155, val: n/a | iter time: 1685.56 ms\n",
            "Epoch 1 | iter 336 step 42 | loss train: 1.117, val: n/a | iter time: 1237.64 ms (step)\n",
            "Epoch 1 | iter 337 step 42 | loss train: 1.124, val: n/a | iter time: 1676.17 ms\n",
            "Epoch 1 | iter 338 step 42 | loss train: 1.063, val: n/a | iter time: 3190.19 ms\n",
            "Epoch 1 | iter 339 step 42 | loss train: 1.078, val: n/a | iter time: 4577.10 ms\n",
            "Epoch 1 | iter 340 step 42 | loss train: 1.085, val: n/a | iter time: 1319.83 ms\n",
            "Epoch 1 | iter 341 step 42 | loss train: 1.080, val: n/a | iter time: 1308.95 ms\n",
            "Epoch 1 | iter 342 step 42 | loss train: 1.084, val: n/a | iter time: 2925.11 ms\n",
            "Epoch 1 | iter 343 step 42 | loss train: 1.116, val: n/a | iter time: 4024.10 ms\n",
            "Epoch 1 | iter 344 step 43 | loss train: 1.169, val: n/a | iter time: 1223.04 ms (step)\n",
            "Epoch 1 | iter 345 step 43 | loss train: 1.155, val: n/a | iter time: 2186.55 ms\n",
            "Epoch 1 | iter 346 step 43 | loss train: 1.213, val: n/a | iter time: 2631.48 ms\n",
            "Epoch 1 | iter 347 step 43 | loss train: 1.186, val: n/a | iter time: 1658.05 ms\n",
            "Epoch 1 | iter 348 step 43 | loss train: 1.153, val: n/a | iter time: 4668.90 ms\n",
            "Epoch 1 | iter 349 step 43 | loss train: 1.198, val: n/a | iter time: 4402.04 ms\n",
            "Epoch 1 | iter 350 step 43 | loss train: 1.153, val: n/a | iter time: 1667.90 ms\n",
            "Epoch 1 | iter 351 step 43 | loss train: 1.097, val: n/a | iter time: 3671.27 ms\n",
            "Epoch 1 | iter 352 step 44 | loss train: 1.050, val: n/a | iter time: 3296.24 ms (step)\n",
            "Epoch 1 | iter 353 step 44 | loss train: 1.073, val: n/a | iter time: 3958.62 ms\n",
            "Epoch 1 | iter 354 step 44 | loss train: 1.087, val: n/a | iter time: 2908.73 ms\n",
            "Epoch 1 | iter 355 step 44 | loss train: 1.127, val: n/a | iter time: 3004.01 ms\n",
            "Epoch 1 | iter 356 step 44 | loss train: 1.122, val: n/a | iter time: 3301.23 ms\n",
            "Epoch 1 | iter 357 step 44 | loss train: 1.157, val: n/a | iter time: 1580.18 ms\n",
            "Epoch 1 | iter 358 step 44 | loss train: 1.203, val: n/a | iter time: 1877.14 ms\n",
            "Epoch 1 | iter 359 step 44 | loss train: 1.270, val: n/a | iter time: 2357.04 ms\n",
            "Epoch 1 | iter 360 step 45 | loss train: 1.277, val: n/a | iter time: 870.76 ms (step)\n",
            "Epoch 1 | iter 361 step 45 | loss train: 1.298, val: n/a | iter time: 3256.84 ms\n",
            "Epoch 1 | iter 362 step 45 | loss train: 1.269, val: n/a | iter time: 1942.28 ms\n",
            "Epoch 1 | iter 363 step 45 | loss train: 1.220, val: n/a | iter time: 4345.88 ms\n",
            "Epoch 1 | iter 364 step 45 | loss train: 1.230, val: n/a | iter time: 4349.67 ms\n",
            "Epoch 1 | iter 365 step 45 | loss train: 1.164, val: n/a | iter time: 2990.72 ms\n",
            "Epoch 1 | iter 366 step 45 | loss train: 1.171, val: n/a | iter time: 981.81 ms\n",
            "Epoch 1 | iter 367 step 45 | loss train: 1.099, val: n/a | iter time: 1882.72 ms\n",
            "Epoch 1 | iter 368 step 46 | loss train: 1.105, val: n/a | iter time: 4493.50 ms (step)\n",
            "Epoch 1 | iter 369 step 46 | loss train: 1.071, val: n/a | iter time: 3180.74 ms\n",
            "Epoch 1 | iter 370 step 46 | loss train: 1.092, val: n/a | iter time: 889.73 ms\n",
            "Epoch 1 | iter 371 step 46 | loss train: 1.092, val: n/a | iter time: 882.06 ms\n",
            "Epoch 1 | iter 372 step 46 | loss train: 1.111, val: n/a | iter time: 3296.36 ms\n",
            "Epoch 1 | iter 373 step 46 | loss train: 1.090, val: n/a | iter time: 3950.37 ms\n",
            "Epoch 1 | iter 374 step 46 | loss train: 1.081, val: n/a | iter time: 1591.24 ms\n",
            "Epoch 1 | iter 375 step 46 | loss train: 1.104, val: n/a | iter time: 5474.71 ms\n",
            "Epoch 1 | iter 376 step 47 | loss train: 1.148, val: n/a | iter time: 5338.69 ms (step)\n",
            "Epoch 1 | iter 377 step 47 | loss train: 1.125, val: n/a | iter time: 2275.18 ms\n",
            "Epoch 1 | iter 378 step 47 | loss train: 1.066, val: n/a | iter time: 3955.25 ms\n",
            "Epoch 1 | iter 379 step 47 | loss train: 1.069, val: n/a | iter time: 1972.94 ms\n",
            "Epoch 1 | iter 380 step 47 | loss train: 1.072, val: n/a | iter time: 3299.38 ms\n",
            "Epoch 1 | iter 381 step 47 | loss train: 1.090, val: n/a | iter time: 692.97 ms\n",
            "Epoch 1 | iter 382 step 47 | loss train: 1.113, val: n/a | iter time: 2211.40 ms\n",
            "Epoch 1 | iter 383 step 47 | loss train: 1.116, val: n/a | iter time: 3731.80 ms\n",
            "Epoch 1 | iter 384 step 48 | loss train: 1.040, val: n/a | iter time: 4274.84 ms (step)\n",
            "Epoch 1 | iter 385 step 48 | loss train: 1.072, val: n/a | iter time: 3191.48 ms\n",
            "Epoch 1 | iter 386 step 48 | loss train: 1.078, val: n/a | iter time: 2521.86 ms\n",
            "Epoch 1 | iter 387 step 48 | loss train: 1.091, val: n/a | iter time: 4309.03 ms\n",
            "Epoch 1 | iter 388 step 48 | loss train: 1.054, val: n/a | iter time: 3618.03 ms\n",
            "Epoch 1 | iter 389 step 48 | loss train: 1.016, val: n/a | iter time: 3550.26 ms\n",
            "Epoch 1 | iter 390 step 48 | loss train: 0.974, val: n/a | iter time: 887.19 ms\n",
            "Epoch 1 | iter 391 step 48 | loss train: 0.985, val: n/a | iter time: 4020.19 ms\n",
            "Epoch 1 | iter 392 step 49 | loss train: 1.067, val: n/a | iter time: 3269.74 ms (step)\n",
            "Epoch 1 | iter 393 step 49 | loss train: 1.058, val: n/a | iter time: 3559.00 ms\n",
            "Epoch 1 | iter 394 step 49 | loss train: 1.099, val: n/a | iter time: 4415.36 ms\n",
            "Epoch 1 | iter 395 step 49 | loss train: 1.045, val: n/a | iter time: 1216.50 ms\n",
            "Epoch 1 | iter 396 step 49 | loss train: 1.071, val: n/a | iter time: 5378.73 ms\n",
            "Epoch 1 | iter 397 step 49 | loss train: 1.116, val: n/a | iter time: 973.96 ms\n",
            "Epoch 1 | iter 398 step 49 | loss train: 1.135, val: n/a | iter time: 3885.67 ms\n",
            "Epoch 1 | iter 399 step 49 | loss train: 1.161, val: n/a | iter time: 4269.91 ms\n",
            "Epoch 1 | iter 400 step 50 | loss train: 1.166, val: n/a | iter time: 1907.12 ms (step)\n",
            "Epoch 1 | iter 401 step 50 | loss train: 1.115, val: n/a | iter time: 969.22 ms\n",
            "Epoch 1 | iter 402 step 50 | loss train: 1.119, val: n/a | iter time: 1645.81 ms\n",
            "Epoch 1 | iter 403 step 50 | loss train: 1.189, val: n/a | iter time: 2672.70 ms\n",
            "Epoch 1 | iter 404 step 50 | loss train: 1.215, val: n/a | iter time: 3336.37 ms\n",
            "Epoch 1 | iter 405 step 50 | loss train: 1.208, val: n/a | iter time: 1668.45 ms\n",
            "Epoch 1 | iter 406 step 50 | loss train: 1.225, val: n/a | iter time: 3648.91 ms\n",
            "Epoch 1 | iter 407 step 50 | loss train: 1.228, val: n/a | iter time: 963.74 ms\n",
            "Epoch 1 | iter 408 step 51 | loss train: 1.139, val: n/a | iter time: 878.44 ms (step)\n",
            "Epoch 1 | iter 409 step 51 | loss train: 1.187, val: n/a | iter time: 4093.31 ms\n",
            "Epoch 1 | iter 410 step 51 | loss train: 1.140, val: n/a | iter time: 975.97 ms\n",
            "Epoch 1 | iter 411 step 51 | loss train: 1.071, val: n/a | iter time: 2244.92 ms\n",
            "Epoch 1 | iter 412 step 51 | loss train: 1.021, val: n/a | iter time: 1677.74 ms\n",
            "Epoch 1 | iter 413 step 51 | loss train: 1.010, val: n/a | iter time: 1586.98 ms\n",
            "Epoch 1 | iter 414 step 51 | loss train: 0.974, val: n/a | iter time: 5334.58 ms\n",
            "Epoch 1 | iter 415 step 51 | loss train: 0.908, val: n/a | iter time: 2630.93 ms\n",
            "Epoch 1 | iter 416 step 52 | loss train: 0.906, val: n/a | iter time: 5035.45 ms (step)\n",
            "Epoch 1 | iter 417 step 52 | loss train: 0.918, val: n/a | iter time: 2981.06 ms\n",
            "Epoch 1 | iter 418 step 52 | loss train: 0.928, val: n/a | iter time: 1588.62 ms\n",
            "Epoch 1 | iter 419 step 52 | loss train: 0.962, val: n/a | iter time: 3935.87 ms\n",
            "Epoch 1 | iter 420 step 52 | loss train: 0.946, val: n/a | iter time: 902.43 ms\n",
            "Epoch 1 | iter 421 step 52 | loss train: 0.985, val: n/a | iter time: 3290.24 ms\n",
            "Epoch 1 | iter 422 step 52 | loss train: 1.022, val: n/a | iter time: 1589.11 ms\n",
            "Epoch 1 | iter 423 step 52 | loss train: 1.032, val: n/a | iter time: 3259.79 ms\n",
            "Epoch 1 | iter 424 step 53 | loss train: 1.053, val: n/a | iter time: 2972.41 ms (step)\n",
            "Epoch 1 | iter 425 step 53 | loss train: 1.021, val: n/a | iter time: 978.76 ms\n",
            "Epoch 1 | iter 426 step 53 | loss train: 1.041, val: n/a | iter time: 5079.35 ms\n",
            "Epoch 1 | iter 427 step 53 | loss train: 1.013, val: n/a | iter time: 1222.89 ms\n",
            "Epoch 1 | iter 428 step 53 | loss train: 1.058, val: n/a | iter time: 3316.11 ms\n",
            "Epoch 1 | iter 429 step 53 | loss train: 1.030, val: n/a | iter time: 3292.01 ms\n",
            "Epoch 1 | iter 430 step 53 | loss train: 0.986, val: n/a | iter time: 3688.01 ms\n",
            "Epoch 1 | iter 431 step 53 | loss train: 0.982, val: n/a | iter time: 1904.93 ms\n",
            "Epoch 1 | iter 432 step 54 | loss train: 0.963, val: n/a | iter time: 907.35 ms (step)\n",
            "Epoch 1 | iter 433 step 54 | loss train: 1.001, val: n/a | iter time: 5493.75 ms\n",
            "Epoch 1 | iter 434 step 54 | loss train: 0.981, val: n/a | iter time: 4034.54 ms\n",
            "Epoch 1 | iter 435 step 54 | loss train: 1.012, val: n/a | iter time: 4049.00 ms\n",
            "Epoch 1 | iter 436 step 54 | loss train: 0.966, val: n/a | iter time: 1697.05 ms\n",
            "Epoch 1 | iter 437 step 54 | loss train: 0.971, val: n/a | iter time: 3988.22 ms\n",
            "Epoch 1 | iter 438 step 54 | loss train: 1.030, val: n/a | iter time: 3644.42 ms\n",
            "Epoch 1 | iter 439 step 54 | loss train: 1.079, val: n/a | iter time: 1491.16 ms\n",
            "Epoch 1 | iter 440 step 55 | loss train: 1.099, val: n/a | iter time: 4379.20 ms (step)\n",
            "Epoch 1 | iter 441 step 55 | loss train: 1.080, val: n/a | iter time: 2365.30 ms\n",
            "Epoch 1 | iter 442 step 55 | loss train: 1.101, val: n/a | iter time: 2599.35 ms\n",
            "Epoch 1 | iter 443 step 55 | loss train: 1.121, val: n/a | iter time: 2619.00 ms\n",
            "Epoch 1 | iter 444 step 55 | loss train: 1.186, val: n/a | iter time: 1236.78 ms\n",
            "Epoch 1 | iter 445 step 55 | loss train: 1.131, val: n/a | iter time: 4242.49 ms\n",
            "Epoch 1 | iter 446 step 55 | loss train: 1.098, val: n/a | iter time: 4701.01 ms\n",
            "Epoch 1 | iter 447 step 55 | loss train: 1.042, val: n/a | iter time: 2996.41 ms\n",
            "Epoch 1 | iter 448 step 56 | loss train: 1.032, val: n/a | iter time: 980.87 ms (step)\n",
            "Epoch 1 | iter 449 step 56 | loss train: 1.005, val: n/a | iter time: 875.09 ms\n",
            "Epoch 1 | iter 450 step 56 | loss train: 1.014, val: n/a | iter time: 992.43 ms\n",
            "Epoch 1 | iter 451 step 56 | loss train: 0.963, val: n/a | iter time: 2513.10 ms\n",
            "Epoch 1 | iter 452 step 56 | loss train: 0.894, val: n/a | iter time: 3886.01 ms\n",
            "Epoch 1 | iter 453 step 56 | loss train: 0.921, val: n/a | iter time: 1585.38 ms\n",
            "Epoch 1 | iter 454 step 56 | loss train: 0.899, val: n/a | iter time: 4405.84 ms\n",
            "Epoch 1 | iter 455 step 56 | loss train: 0.911, val: n/a | iter time: 1316.34 ms\n",
            "Epoch 1 | iter 456 step 57 | loss train: 0.985, val: n/a | iter time: 1962.87 ms (step)\n",
            "Epoch 1 | iter 457 step 57 | loss train: 0.998, val: n/a | iter time: 1908.44 ms\n",
            "Epoch 1 | iter 458 step 57 | loss train: 1.051, val: n/a | iter time: 2649.41 ms\n",
            "Epoch 1 | iter 459 step 57 | loss train: 1.065, val: n/a | iter time: 3887.25 ms\n",
            "Epoch 1 | iter 460 step 57 | loss train: 1.103, val: n/a | iter time: 871.37 ms\n",
            "Epoch 1 | iter 461 step 57 | loss train: 1.154, val: n/a | iter time: 877.21 ms\n",
            "Epoch 1 | iter 462 step 57 | loss train: 1.175, val: n/a | iter time: 1315.90 ms\n",
            "Epoch 1 | iter 463 step 57 | loss train: 1.172, val: n/a | iter time: 5404.72 ms\n",
            "Epoch 1 | iter 464 step 58 | loss train: 1.063, val: n/a | iter time: 589.51 ms (step)\n",
            "Epoch 1 | iter 465 step 58 | loss train: 1.078, val: n/a | iter time: 3880.38 ms\n",
            "Epoch 1 | iter 466 step 58 | loss train: 0.969, val: n/a | iter time: 996.77 ms\n",
            "Epoch 1 | iter 467 step 58 | loss train: 0.946, val: n/a | iter time: 3964.11 ms\n",
            "Epoch 1 | iter 468 step 58 | loss train: 0.943, val: n/a | iter time: 3965.25 ms\n",
            "Epoch 1 | iter 469 step 58 | loss train: 0.909, val: n/a | iter time: 4012.62 ms\n",
            "Epoch 1 | iter 470 step 58 | loss train: 0.895, val: n/a | iter time: 3634.15 ms\n",
            "Epoch 1 | iter 471 step 58 | loss train: 0.915, val: n/a | iter time: 975.04 ms\n",
            "Epoch 1 | iter 472 step 59 | loss train: 0.917, val: n/a | iter time: 1324.83 ms (step)\n",
            "Epoch 1 | iter 473 step 59 | loss train: 0.940, val: n/a | iter time: 1675.73 ms\n",
            "Epoch 1 | iter 474 step 59 | loss train: 0.992, val: n/a | iter time: 4251.93 ms\n",
            "Epoch 1 | iter 475 step 59 | loss train: 1.014, val: n/a | iter time: 3287.09 ms\n",
            "Epoch 1 | iter 476 step 59 | loss train: 0.989, val: n/a | iter time: 1882.48 ms\n",
            "Epoch 1 | iter 477 step 59 | loss train: 0.958, val: n/a | iter time: 3980.95 ms\n",
            "Epoch 1 | iter 478 step 59 | loss train: 0.946, val: n/a | iter time: 965.43 ms\n",
            "Epoch 1 | iter 479 step 59 | loss train: 0.982, val: n/a | iter time: 2306.75 ms\n",
            "Epoch 1 | iter 480 step 60 | loss train: 1.007, val: n/a | iter time: 974.94 ms (step)\n",
            "Epoch 1 | iter 481 step 60 | loss train: 0.983, val: n/a | iter time: 1580.26 ms\n",
            "Epoch 1 | iter 482 step 60 | loss train: 0.954, val: n/a | iter time: 5363.97 ms\n",
            "Epoch 1 | iter 483 step 60 | loss train: 0.997, val: n/a | iter time: 1951.43 ms\n",
            "Epoch 1 | iter 484 step 60 | loss train: 1.031, val: n/a | iter time: 3885.47 ms\n",
            "Epoch 1 | iter 485 step 60 | loss train: 1.073, val: n/a | iter time: 5134.32 ms\n",
            "Epoch 1 | iter 486 step 60 | loss train: 1.070, val: n/a | iter time: 3698.90 ms\n",
            "Epoch 1 | iter 487 step 60 | loss train: 0.993, val: n/a | iter time: 986.03 ms\n",
            "Epoch 1 | iter 488 step 61 | loss train: 1.014, val: n/a | iter time: 4967.28 ms (step)\n",
            "Epoch 1 | iter 489 step 61 | loss train: 1.008, val: n/a | iter time: 2553.33 ms\n",
            "Epoch 1 | iter 490 step 61 | loss train: 1.044, val: n/a | iter time: 3626.03 ms\n",
            "Epoch 1 | iter 491 step 61 | loss train: 0.952, val: n/a | iter time: 2936.28 ms\n",
            "Epoch 1 | iter 492 step 61 | loss train: 0.993, val: n/a | iter time: 2338.14 ms\n",
            "Epoch 1 | iter 493 step 61 | loss train: 1.001, val: n/a | iter time: 4515.64 ms\n",
            "Epoch 1 | iter 494 step 61 | loss train: 1.042, val: n/a | iter time: 5128.01 ms\n",
            "Epoch 1 | iter 495 step 61 | loss train: 1.085, val: n/a | iter time: 2621.40 ms\n",
            "Epoch 1 | iter 496 step 62 | loss train: 1.071, val: n/a | iter time: 5389.75 ms (step)\n",
            "Epoch 1 | iter 497 step 62 | loss train: 1.106, val: n/a | iter time: 4832.51 ms\n",
            "Epoch 1 | iter 498 step 62 | loss train: 1.084, val: n/a | iter time: 1231.68 ms\n",
            "Epoch 1 | iter 499 step 62 | loss train: 1.142, val: n/a | iter time: 1234.13 ms\n",
            "Epoch 1 | iter 500 step 62 | loss train: 1.129, val: n/a | iter time: 5055.75 ms\n",
            "Epoch 1 | iter 501 step 62 | loss train: 1.085, val: n/a | iter time: 3981.75 ms\n",
            "Epoch 1 | iter 502 step 62 | loss train: 1.082, val: n/a | iter time: 3558.91 ms\n",
            "Epoch 1 | iter 503 step 62 | loss train: 1.040, val: n/a | iter time: 877.96 ms\n",
            "Epoch 1 | iter 504 step 63 | loss train: 1.041, val: n/a | iter time: 3662.52 ms (step)\n",
            "Epoch 1 | iter 505 step 63 | loss train: 0.940, val: n/a | iter time: 2623.42 ms\n",
            "Epoch 1 | iter 506 step 63 | loss train: 0.920, val: n/a | iter time: 893.18 ms\n",
            "Epoch 1 | iter 507 step 63 | loss train: 0.940, val: n/a | iter time: 3287.58 ms\n",
            "Epoch 1 | iter 508 step 63 | loss train: 0.868, val: n/a | iter time: 3905.18 ms\n",
            "Epoch 1 | iter 509 step 63 | loss train: 0.972, val: n/a | iter time: 3264.06 ms\n",
            "Epoch 1 | iter 510 step 63 | loss train: 0.985, val: n/a | iter time: 2320.29 ms\n",
            "Epoch 1 | iter 511 step 63 | loss train: 1.036, val: n/a | iter time: 2284.26 ms\n",
            "Epoch 1 | iter 512 step 64 | loss train: 1.056, val: n/a | iter time: 2222.75 ms (step)\n",
            "Epoch 1 | iter 513 step 64 | loss train: 1.170, val: n/a | iter time: 3930.26 ms\n",
            "Epoch 1 | iter 514 step 64 | loss train: 1.161, val: n/a | iter time: 977.01 ms\n",
            "Epoch 1 | iter 515 step 64 | loss train: 1.161, val: n/a | iter time: 2209.63 ms\n",
            "Epoch 1 | iter 516 step 64 | loss train: 1.189, val: n/a | iter time: 3262.83 ms\n",
            "Epoch 1 | iter 517 step 64 | loss train: 1.124, val: n/a | iter time: 3788.06 ms\n",
            "Epoch 1 | iter 518 step 64 | loss train: 1.075, val: n/a | iter time: 4462.12 ms\n",
            "Epoch 1 | iter 519 step 64 | loss train: 1.028, val: n/a | iter time: 878.94 ms\n",
            "Epoch 1 | iter 520 step 65 | loss train: 0.978, val: n/a | iter time: 689.21 ms (step)\n",
            "Epoch 1 | iter 521 step 65 | loss train: 0.922, val: n/a | iter time: 2943.15 ms\n",
            "Epoch 1 | iter 522 step 65 | loss train: 0.932, val: n/a | iter time: 1983.27 ms\n",
            "Epoch 1 | iter 523 step 65 | loss train: 0.857, val: n/a | iter time: 874.37 ms\n",
            "Epoch 1 | iter 524 step 65 | loss train: 0.860, val: n/a | iter time: 3969.85 ms\n",
            "Epoch 1 | iter 525 step 65 | loss train: 0.800, val: n/a | iter time: 881.50 ms\n",
            "Epoch 1 | iter 526 step 65 | loss train: 0.837, val: n/a | iter time: 1331.05 ms\n",
            "Epoch 1 | iter 527 step 65 | loss train: 0.874, val: n/a | iter time: 1237.18 ms\n",
            "Epoch 1 | iter 528 step 66 | loss train: 0.919, val: n/a | iter time: 4407.21 ms (step)\n",
            "Epoch 1 | iter 529 step 66 | loss train: 0.963, val: n/a | iter time: 2601.27 ms\n",
            "Epoch 1 | iter 530 step 66 | loss train: 0.980, val: n/a | iter time: 3275.32 ms\n",
            "Epoch 1 | iter 531 step 66 | loss train: 1.040, val: n/a | iter time: 3903.00 ms\n",
            "Epoch 1 | iter 532 step 66 | loss train: 0.995, val: n/a | iter time: 5526.76 ms\n",
            "Epoch 1 | iter 533 step 66 | loss train: 1.033, val: n/a | iter time: 2311.95 ms\n",
            "Epoch 1 | iter 534 step 66 | loss train: 0.996, val: n/a | iter time: 3312.77 ms\n",
            "Epoch 1 | iter 535 step 66 | loss train: 1.029, val: n/a | iter time: 4136.31 ms\n",
            "Epoch 1 | iter 536 step 67 | loss train: 0.982, val: n/a | iter time: 2562.36 ms (step)\n",
            "Epoch 1 | iter 537 step 67 | loss train: 0.988, val: n/a | iter time: 3364.35 ms\n",
            "Epoch 1 | iter 538 step 67 | loss train: 0.975, val: n/a | iter time: 698.83 ms\n",
            "Epoch 1 | iter 539 step 67 | loss train: 0.939, val: n/a | iter time: 1203.96 ms\n",
            "Epoch 1 | iter 540 step 67 | loss train: 0.997, val: n/a | iter time: 2655.88 ms\n",
            "Epoch 1 | iter 541 step 67 | loss train: 1.100, val: n/a | iter time: 1664.07 ms\n",
            "Epoch 1 | iter 542 step 67 | loss train: 1.119, val: n/a | iter time: 1686.64 ms\n",
            "Epoch 1 | iter 543 step 67 | loss train: 1.092, val: n/a | iter time: 1255.22 ms\n",
            "Epoch 1 | iter 544 step 68 | loss train: 1.098, val: n/a | iter time: 5590.11 ms (step)\n",
            "Epoch 1 | iter 545 step 68 | loss train: 1.019, val: n/a | iter time: 875.18 ms\n",
            "Epoch 1 | iter 546 step 68 | loss train: 1.000, val: n/a | iter time: 1592.96 ms\n",
            "Epoch 1 | iter 547 step 68 | loss train: 1.044, val: n/a | iter time: 2315.27 ms\n",
            "Epoch 1 | iter 548 step 68 | loss train: 0.985, val: n/a | iter time: 1613.63 ms\n",
            "Epoch 1 | iter 549 step 68 | loss train: 0.862, val: n/a | iter time: 1681.44 ms\n",
            "Epoch 1 | iter 550 step 68 | loss train: 0.829, val: n/a | iter time: 1597.00 ms\n",
            "Epoch 1 | iter 551 step 68 | loss train: 0.833, val: n/a | iter time: 1963.90 ms\n",
            "Epoch 1 | iter 552 step 69 | loss train: 0.914, val: n/a | iter time: 3244.24 ms (step)\n",
            "Epoch 1 | iter 553 step 69 | loss train: 0.957, val: n/a | iter time: 2561.44 ms\n",
            "Epoch 1 | iter 554 step 69 | loss train: 0.992, val: n/a | iter time: 1914.41 ms\n",
            "Epoch 1 | iter 555 step 69 | loss train: 1.002, val: n/a | iter time: 2208.55 ms\n",
            "Epoch 1 | iter 556 step 69 | loss train: 1.042, val: n/a | iter time: 992.74 ms\n",
            "Epoch 1 | iter 557 step 69 | loss train: 1.097, val: n/a | iter time: 1224.30 ms\n",
            "Epoch 1 | iter 558 step 69 | loss train: 1.145, val: n/a | iter time: 2671.78 ms\n",
            "Epoch 1 | iter 559 step 69 | loss train: 1.090, val: n/a | iter time: 3594.50 ms\n",
            "Epoch 1 | iter 560 step 70 | loss train: 0.999, val: n/a | iter time: 4016.60 ms (step)\n",
            "Epoch 1 | iter 561 step 70 | loss train: 0.963, val: n/a | iter time: 895.46 ms\n",
            "Epoch 1 | iter 562 step 70 | loss train: 0.997, val: n/a | iter time: 2614.24 ms\n",
            "Epoch 1 | iter 563 step 70 | loss train: 1.015, val: n/a | iter time: 4638.58 ms\n",
            "Epoch 1 | iter 564 step 70 | loss train: 1.043, val: n/a | iter time: 4179.52 ms\n",
            "Epoch 1 | iter 565 step 70 | loss train: 1.007, val: n/a | iter time: 2589.27 ms\n",
            "Epoch 1 | iter 566 step 70 | loss train: 0.969, val: n/a | iter time: 1900.54 ms\n",
            "Epoch 1 | iter 567 step 70 | loss train: 0.986, val: n/a | iter time: 4519.37 ms\n",
            "Epoch 1 | iter 568 step 71 | loss train: 0.997, val: n/a | iter time: 3907.46 ms (step)\n",
            "Epoch 1 | iter 569 step 71 | loss train: 1.042, val: n/a | iter time: 684.57 ms\n",
            "Epoch 1 | iter 570 step 71 | loss train: 0.990, val: n/a | iter time: 4441.65 ms\n",
            "Epoch 1 | iter 571 step 71 | loss train: 0.964, val: n/a | iter time: 4327.29 ms\n",
            "Epoch 1 | iter 572 step 71 | loss train: 0.935, val: n/a | iter time: 4015.12 ms\n",
            "Epoch 1 | iter 573 step 71 | loss train: 0.927, val: n/a | iter time: 1304.98 ms\n",
            "Epoch 1 | iter 574 step 71 | loss train: 0.874, val: n/a | iter time: 888.41 ms\n",
            "Epoch 1 | iter 575 step 71 | loss train: 0.892, val: n/a | iter time: 3274.89 ms\n",
            "Epoch 1 | iter 576 step 72 | loss train: 0.939, val: n/a | iter time: 1688.80 ms (step)\n",
            "Epoch 1 | iter 577 step 72 | loss train: 0.990, val: n/a | iter time: 4012.11 ms\n",
            "Epoch 1 | iter 578 step 72 | loss train: 1.033, val: n/a | iter time: 1704.25 ms\n",
            "Epoch 1 | iter 579 step 72 | loss train: 1.018, val: n/a | iter time: 5398.73 ms\n",
            "Epoch 1 | iter 580 step 72 | loss train: 1.073, val: n/a | iter time: 4018.01 ms\n",
            "Epoch 1 | iter 581 step 72 | loss train: 1.111, val: n/a | iter time: 2279.25 ms\n",
            "Epoch 1 | iter 582 step 72 | loss train: 1.184, val: n/a | iter time: 1888.08 ms\n",
            "Epoch 1 | iter 583 step 72 | loss train: 1.194, val: n/a | iter time: 3938.60 ms\n",
            "Epoch 1 | iter 584 step 73 | loss train: 1.203, val: n/a | iter time: 3356.16 ms (step)\n",
            "Epoch 1 | iter 585 step 73 | loss train: 1.134, val: n/a | iter time: 3900.81 ms\n",
            "Epoch 1 | iter 586 step 73 | loss train: 1.092, val: n/a | iter time: 1329.25 ms\n",
            "Epoch 1 | iter 587 step 73 | loss train: 1.063, val: n/a | iter time: 970.82 ms\n",
            "Epoch 1 | iter 588 step 73 | loss train: 0.992, val: n/a | iter time: 1887.48 ms\n",
            "Epoch 1 | iter 589 step 73 | loss train: 1.012, val: n/a | iter time: 2655.44 ms\n",
            "Epoch 1 | iter 590 step 73 | loss train: 0.997, val: n/a | iter time: 5558.90 ms\n",
            "Epoch 1 | iter 591 step 73 | loss train: 0.980, val: n/a | iter time: 2010.58 ms\n",
            "Epoch 1 | iter 592 step 74 | loss train: 0.927, val: n/a | iter time: 1914.10 ms (step)\n",
            "Epoch 1 | iter 593 step 74 | loss train: 0.890, val: n/a | iter time: 975.27 ms\n",
            "Epoch 1 | iter 594 step 74 | loss train: 0.907, val: n/a | iter time: 2217.51 ms\n",
            "Epoch 1 | iter 595 step 74 | loss train: 0.948, val: n/a | iter time: 3010.32 ms\n",
            "Epoch 1 | iter 596 step 74 | loss train: 0.993, val: n/a | iter time: 1906.88 ms\n",
            "Epoch 1 | iter 597 step 74 | loss train: 0.969, val: n/a | iter time: 1978.37 ms\n",
            "Epoch 1 | iter 598 step 74 | loss train: 1.018, val: n/a | iter time: 1227.36 ms\n",
            "Epoch 1 | iter 599 step 74 | loss train: 1.059, val: n/a | iter time: 4712.70 ms\n",
            "Epoch 1 | iter 600 step 75 | loss train: 1.073, val: n/a | iter time: 1906.00 ms (step)\n",
            "Epoch 1 | iter 601 step 75 | loss train: 1.098, val: n/a | iter time: 1916.25 ms\n",
            "Epoch 1 | iter 602 step 75 | loss train: 1.104, val: n/a | iter time: 896.13 ms\n",
            "Epoch 1 | iter 603 step 75 | loss train: 1.076, val: n/a | iter time: 3603.19 ms\n",
            "Epoch 1 | iter 604 step 75 | loss train: 1.096, val: n/a | iter time: 3223.57 ms\n",
            "Epoch 1 | iter 605 step 75 | loss train: 1.102, val: n/a | iter time: 4058.16 ms\n",
            "Epoch 1 | iter 606 step 75 | loss train: 1.109, val: n/a | iter time: 1599.88 ms\n",
            "Epoch 1 | iter 607 step 75 | loss train: 1.094, val: n/a | iter time: 1991.10 ms\n",
            "Epoch 1 | iter 608 step 76 | loss train: 1.097, val: n/a | iter time: 2237.01 ms (step)\n",
            "Epoch 1 | iter 609 step 76 | loss train: 1.095, val: n/a | iter time: 1597.28 ms\n",
            "Epoch 1 | iter 610 step 76 | loss train: 1.072, val: n/a | iter time: 5413.42 ms\n",
            "Epoch 1 | iter 611 step 76 | loss train: 1.153, val: n/a | iter time: 2570.28 ms\n",
            "Epoch 1 | iter 612 step 76 | loss train: 1.097, val: n/a | iter time: 2378.19 ms\n",
            "Epoch 1 | iter 613 step 76 | loss train: 1.058, val: n/a | iter time: 5465.75 ms\n",
            "Epoch 1 | iter 614 step 76 | loss train: 0.953, val: n/a | iter time: 3900.31 ms\n",
            "Epoch 1 | iter 615 step 76 | loss train: 0.968, val: n/a | iter time: 3063.20 ms\n",
            "Epoch 1 | iter 616 step 77 | loss train: 0.973, val: n/a | iter time: 1247.74 ms (step)\n",
            "Epoch 1 | iter 617 step 77 | loss train: 0.996, val: n/a | iter time: 1665.14 ms\n",
            "Epoch 1 | iter 618 step 77 | loss train: 1.004, val: n/a | iter time: 2329.83 ms\n",
            "Epoch 1 | iter 619 step 77 | loss train: 0.931, val: n/a | iter time: 979.76 ms\n",
            "Epoch 1 | iter 620 step 77 | loss train: 0.943, val: n/a | iter time: 1924.29 ms\n",
            "Epoch 1 | iter 621 step 77 | loss train: 0.944, val: n/a | iter time: 5299.52 ms\n",
            "Epoch 1 | iter 622 step 77 | loss train: 0.996, val: n/a | iter time: 4460.45 ms\n",
            "Epoch 1 | iter 623 step 77 | loss train: 0.979, val: n/a | iter time: 4773.39 ms\n",
            "Epoch 1 | iter 624 step 78 | loss train: 0.918, val: n/a | iter time: 972.84 ms (step)\n",
            "Epoch 1 | iter 625 step 78 | loss train: 0.915, val: n/a | iter time: 1887.13 ms\n",
            "Epoch 1 | iter 626 step 78 | loss train: 0.936, val: n/a | iter time: 1223.97 ms\n",
            "Epoch 1 | iter 627 step 78 | loss train: 0.946, val: n/a | iter time: 5074.21 ms\n",
            "Epoch 1 | iter 628 step 78 | loss train: 0.925, val: n/a | iter time: 3001.14 ms\n",
            "Epoch 1 | iter 629 step 78 | loss train: 0.941, val: n/a | iter time: 985.34 ms\n",
            "Epoch 1 | iter 630 step 78 | loss train: 1.014, val: n/a | iter time: 3909.44 ms\n",
            "Epoch 1 | iter 631 step 78 | loss train: 1.002, val: n/a | iter time: 1656.54 ms\n",
            "Epoch 1 | iter 632 step 79 | loss train: 1.025, val: n/a | iter time: 3620.53 ms (step)\n",
            "Epoch 1 | iter 633 step 79 | loss train: 1.068, val: n/a | iter time: 2641.56 ms\n",
            "Epoch 1 | iter 634 step 79 | loss train: 1.076, val: n/a | iter time: 1238.89 ms\n",
            "Epoch 1 | iter 635 step 79 | loss train: 1.085, val: n/a | iter time: 5268.81 ms\n",
            "Epoch 1 | iter 636 step 79 | loss train: 1.108, val: n/a | iter time: 1670.37 ms\n",
            "Epoch 1 | iter 637 step 79 | loss train: 1.111, val: n/a | iter time: 4311.70 ms\n",
            "Epoch 1 | iter 638 step 79 | loss train: 1.040, val: n/a | iter time: 4709.10 ms\n",
            "Epoch 1 | iter 639 step 79 | loss train: 1.061, val: n/a | iter time: 4123.61 ms\n",
            "Epoch 1 | iter 640 step 80 | loss train: 1.126, val: n/a | iter time: 5229.46 ms (step)\n",
            "Epoch 1 | iter 641 step 80 | loss train: 1.123, val: n/a | iter time: 2637.75 ms\n",
            "Epoch 1 | iter 642 step 80 | loss train: 1.118, val: n/a | iter time: 2628.92 ms\n",
            "Epoch 1 | iter 643 step 80 | loss train: 1.142, val: n/a | iter time: 4359.70 ms\n",
            "Epoch 1 | iter 644 step 80 | loss train: 1.160, val: n/a | iter time: 2582.45 ms\n",
            "Epoch 1 | iter 645 step 80 | loss train: 1.192, val: n/a | iter time: 1891.10 ms\n",
            "Epoch 1 | iter 646 step 80 | loss train: 1.211, val: n/a | iter time: 3302.72 ms\n",
            "Epoch 1 | iter 647 step 80 | loss train: 1.237, val: n/a | iter time: 4696.39 ms\n",
            "Epoch 1 | iter 648 step 81 | loss train: 1.190, val: n/a | iter time: 3296.18 ms (step)\n",
            "Epoch 1 | iter 649 step 81 | loss train: 1.171, val: n/a | iter time: 1593.64 ms\n",
            "Epoch 1 | iter 650 step 81 | loss train: 1.193, val: n/a | iter time: 3273.04 ms\n",
            "Epoch 1 | iter 651 step 81 | loss train: 1.163, val: n/a | iter time: 2929.15 ms\n",
            "Epoch 1 | iter 652 step 81 | loss train: 1.119, val: n/a | iter time: 1233.94 ms\n",
            "Epoch 1 | iter 653 step 81 | loss train: 1.099, val: n/a | iter time: 5225.17 ms\n",
            "Epoch 1 | iter 654 step 81 | loss train: 1.082, val: n/a | iter time: 2255.69 ms\n",
            "Epoch 1 | iter 655 step 81 | loss train: 1.032, val: n/a | iter time: 5394.27 ms\n",
            "Epoch 1 | iter 656 step 82 | loss train: 1.040, val: n/a | iter time: 4343.44 ms (step)\n",
            "Epoch 1 | iter 657 step 82 | loss train: 0.991, val: n/a | iter time: 981.83 ms\n",
            "Epoch 1 | iter 658 step 82 | loss train: 0.958, val: n/a | iter time: 1332.18 ms\n",
            "Epoch 1 | iter 659 step 82 | loss train: 0.935, val: n/a | iter time: 1230.93 ms\n",
            "Epoch 1 | iter 660 step 82 | loss train: 0.919, val: n/a | iter time: 1669.43 ms\n",
            "Epoch 1 | iter 661 step 82 | loss train: 0.910, val: n/a | iter time: 3223.49 ms\n",
            "Epoch 1 | iter 662 step 82 | loss train: 0.930, val: n/a | iter time: 1608.52 ms\n",
            "Epoch 1 | iter 663 step 82 | loss train: 0.986, val: n/a | iter time: 4285.75 ms\n",
            "Epoch 1 | iter 664 step 83 | loss train: 1.030, val: n/a | iter time: 1968.18 ms (step)\n",
            "Epoch 1 | iter 665 step 83 | loss train: 1.091, val: n/a | iter time: 4682.64 ms\n",
            "Epoch 1 | iter 666 step 83 | loss train: 1.085, val: n/a | iter time: 1898.55 ms\n",
            "Epoch 1 | iter 667 step 83 | loss train: 1.087, val: n/a | iter time: 3259.41 ms\n",
            "Epoch 1 | iter 668 step 83 | loss train: 1.077, val: n/a | iter time: 957.60 ms\n",
            "Epoch 1 | iter 669 step 83 | loss train: 1.022, val: n/a | iter time: 894.21 ms\n",
            "Epoch 1 | iter 670 step 83 | loss train: 1.001, val: n/a | iter time: 2983.00 ms\n",
            "Epoch 1 | iter 671 step 83 | loss train: 0.932, val: n/a | iter time: 3183.37 ms\n",
            "Epoch 1 | iter 672 step 84 | loss train: 0.891, val: n/a | iter time: 3332.97 ms (step)\n",
            "Epoch 1 | iter 673 step 84 | loss train: 0.787, val: n/a | iter time: 991.78 ms\n",
            "Epoch 1 | iter 674 step 84 | loss train: 0.773, val: n/a | iter time: 995.62 ms\n",
            "Epoch 1 | iter 675 step 84 | loss train: 0.764, val: n/a | iter time: 4530.61 ms\n",
            "Epoch 1 | iter 676 step 84 | loss train: 0.856, val: n/a | iter time: 2372.11 ms\n",
            "Epoch 1 | iter 677 step 84 | loss train: 0.887, val: n/a | iter time: 3178.94 ms\n",
            "Epoch 1 | iter 678 step 84 | loss train: 0.903, val: n/a | iter time: 4684.77 ms\n",
            "Epoch 1 | iter 679 step 84 | loss train: 0.919, val: n/a | iter time: 3274.95 ms\n",
            "Epoch 1 | iter 680 step 85 | loss train: 0.927, val: n/a | iter time: 3765.37 ms (step)\n",
            "Epoch 1 | iter 681 step 85 | loss train: 0.978, val: n/a | iter time: 983.64 ms\n",
            "Epoch 1 | iter 682 step 85 | loss train: 0.998, val: n/a | iter time: 1595.11 ms\n",
            "Epoch 1 | iter 683 step 85 | loss train: 0.963, val: n/a | iter time: 3000.26 ms\n",
            "Epoch 1 | iter 684 step 85 | loss train: 0.917, val: n/a | iter time: 1893.26 ms\n",
            "Epoch 1 | iter 685 step 85 | loss train: 0.925, val: n/a | iter time: 1592.77 ms\n",
            "Epoch 1 | iter 686 step 85 | loss train: 0.924, val: n/a | iter time: 2923.25 ms\n",
            "Epoch 1 | iter 687 step 85 | loss train: 0.867, val: n/a | iter time: 5376.26 ms\n",
            "Epoch 1 | iter 688 step 86 | loss train: 0.857, val: n/a | iter time: 3283.13 ms (step)\n",
            "Epoch 1 | iter 689 step 86 | loss train: 0.912, val: n/a | iter time: 5253.72 ms\n",
            "Epoch 1 | iter 690 step 86 | loss train: 0.917, val: n/a | iter time: 1889.32 ms\n",
            "Epoch 1 | iter 691 step 86 | loss train: 0.979, val: n/a | iter time: 5413.50 ms\n",
            "Epoch 1 | iter 692 step 86 | loss train: 0.985, val: n/a | iter time: 4831.16 ms\n",
            "Epoch 1 | iter 693 step 86 | loss train: 1.008, val: n/a | iter time: 2580.67 ms\n",
            "Epoch 1 | iter 694 step 86 | loss train: 0.992, val: n/a | iter time: 1293.13 ms\n",
            "Epoch 1 | iter 695 step 86 | loss train: 1.057, val: n/a | iter time: 3943.89 ms\n",
            "Epoch 1 | iter 696 step 87 | loss train: 1.085, val: n/a | iter time: 1956.65 ms (step)\n",
            "Epoch 1 | iter 697 step 87 | loss train: 1.062, val: n/a | iter time: 3720.06 ms\n",
            "Epoch 1 | iter 698 step 87 | loss train: 1.103, val: n/a | iter time: 955.96 ms\n",
            "Epoch 1 | iter 699 step 87 | loss train: 1.090, val: n/a | iter time: 3253.90 ms\n",
            "Epoch 1 | iter 700 step 87 | loss train: 1.111, val: n/a | iter time: 4541.32 ms\n",
            "Epoch 1 | iter 701 step 87 | loss train: 1.082, val: n/a | iter time: 2216.31 ms\n",
            "Epoch 1 | iter 702 step 87 | loss train: 1.093, val: n/a | iter time: 2610.59 ms\n",
            "Epoch 1 | iter 703 step 87 | loss train: 1.067, val: n/a | iter time: 882.53 ms\n",
            "Epoch 1 | iter 704 step 88 | loss train: 1.007, val: n/a | iter time: 2928.81 ms (step)\n",
            "Epoch 1 | iter 705 step 88 | loss train: 1.018, val: n/a | iter time: 3372.93 ms\n",
            "Epoch 1 | iter 706 step 88 | loss train: 0.955, val: n/a | iter time: 1607.41 ms\n",
            "Epoch 1 | iter 707 step 88 | loss train: 0.963, val: n/a | iter time: 2225.93 ms\n",
            "Epoch 1 | iter 708 step 88 | loss train: 0.891, val: n/a | iter time: 2592.52 ms\n",
            "Epoch 1 | iter 709 step 88 | loss train: 0.883, val: n/a | iter time: 2567.64 ms\n",
            "Epoch 1 | iter 710 step 88 | loss train: 0.884, val: n/a | iter time: 3935.15 ms\n",
            "Epoch 1 | iter 711 step 88 | loss train: 0.935, val: n/a | iter time: 3938.50 ms\n",
            "Epoch 1 | iter 712 step 89 | loss train: 1.022, val: n/a | iter time: 1694.64 ms (step)\n",
            "Epoch 1 | iter 713 step 89 | loss train: 0.970, val: n/a | iter time: 3591.99 ms\n",
            "Epoch 1 | iter 714 step 89 | loss train: 0.981, val: n/a | iter time: 1893.15 ms\n",
            "Epoch 1 | iter 715 step 89 | loss train: 0.981, val: n/a | iter time: 5454.98 ms\n",
            "Epoch 1 | iter 716 step 89 | loss train: 1.023, val: n/a | iter time: 1872.02 ms\n",
            "Epoch 1 | iter 717 step 89 | loss train: 1.052, val: n/a | iter time: 3898.02 ms\n",
            "Epoch 1 | iter 718 step 89 | loss train: 1.071, val: n/a | iter time: 4317.64 ms\n",
            "Epoch 1 | iter 719 step 89 | loss train: 1.069, val: n/a | iter time: 4285.42 ms\n",
            "Epoch 1 | iter 720 step 90 | loss train: 0.997, val: n/a | iter time: 2595.93 ms (step)\n",
            "Epoch 1 | iter 721 step 90 | loss train: 0.992, val: n/a | iter time: 896.61 ms\n",
            "Epoch 1 | iter 722 step 90 | loss train: 0.985, val: n/a | iter time: 2225.79 ms\n",
            "Epoch 1 | iter 723 step 90 | loss train: 1.024, val: n/a | iter time: 3647.59 ms\n",
            "Epoch 1 | iter 724 step 90 | loss train: 0.991, val: n/a | iter time: 2672.71 ms\n",
            "Epoch 1 | iter 725 step 90 | loss train: 0.994, val: n/a | iter time: 2594.16 ms\n",
            "Epoch 1 | iter 726 step 90 | loss train: 0.968, val: n/a | iter time: 4481.69 ms\n",
            "Epoch 1 | iter 727 step 90 | loss train: 0.965, val: n/a | iter time: 1897.56 ms\n",
            "Epoch 1 | iter 728 step 91 | loss train: 0.977, val: n/a | iter time: 3998.54 ms (step)\n",
            "Epoch 1 | iter 729 step 91 | loss train: 1.020, val: n/a | iter time: 898.16 ms\n",
            "Epoch 1 | iter 730 step 91 | loss train: 1.056, val: n/a | iter time: 1590.04 ms\n",
            "Epoch 1 | iter 731 step 91 | loss train: 0.992, val: n/a | iter time: 2228.87 ms\n",
            "Epoch 1 | iter 732 step 91 | loss train: 1.021, val: n/a | iter time: 3046.33 ms\n",
            "Epoch 1 | iter 733 step 91 | loss train: 1.017, val: n/a | iter time: 3786.78 ms\n",
            "Epoch 1 | iter 734 step 91 | loss train: 1.029, val: n/a | iter time: 1218.96 ms\n",
            "Epoch 1 | iter 735 step 91 | loss train: 1.004, val: n/a | iter time: 2220.11 ms\n",
            "Epoch 1 | iter 736 step 92 | loss train: 1.025, val: n/a | iter time: 4496.85 ms (step)\n",
            "Epoch 1 | iter 737 step 92 | loss train: 1.016, val: n/a | iter time: 3672.80 ms\n",
            "Epoch 1 | iter 738 step 92 | loss train: 1.029, val: n/a | iter time: 1705.73 ms\n",
            "Epoch 1 | iter 739 step 92 | loss train: 1.012, val: n/a | iter time: 1330.64 ms\n",
            "Epoch 1 | iter 740 step 92 | loss train: 0.952, val: n/a | iter time: 1302.84 ms\n",
            "Epoch 1 | iter 741 step 92 | loss train: 0.976, val: n/a | iter time: 4010.71 ms\n",
            "Epoch 1 | iter 742 step 92 | loss train: 0.976, val: n/a | iter time: 1305.66 ms\n",
            "Epoch 1 | iter 743 step 92 | loss train: 0.955, val: n/a | iter time: 1572.92 ms\n",
            "Epoch 1 | iter 744 step 93 | loss train: 0.976, val: n/a | iter time: 1607.07 ms (step)\n",
            "Epoch 1 | iter 745 step 93 | loss train: 0.948, val: n/a | iter time: 1217.62 ms\n",
            "Epoch 1 | iter 746 step 93 | loss train: 0.909, val: n/a | iter time: 1975.07 ms\n",
            "Epoch 1 | iter 747 step 93 | loss train: 0.927, val: n/a | iter time: 3281.98 ms\n",
            "Epoch 1 | iter 748 step 93 | loss train: 1.001, val: n/a | iter time: 2222.21 ms\n",
            "Epoch 1 | iter 749 step 93 | loss train: 1.007, val: n/a | iter time: 2211.68 ms\n",
            "Epoch 1 | iter 750 step 93 | loss train: 0.993, val: n/a | iter time: 4766.03 ms\n",
            "Epoch 1 | iter 751 step 93 | loss train: 0.987, val: n/a | iter time: 1237.26 ms\n",
            "Epoch 1 | iter 752 step 94 | loss train: 0.965, val: n/a | iter time: 3217.84 ms (step)\n",
            "Epoch 1 | iter 753 step 94 | loss train: 0.997, val: n/a | iter time: 2187.27 ms\n",
            "Epoch 1 | iter 754 step 94 | loss train: 0.994, val: n/a | iter time: 3635.10 ms\n",
            "Epoch 1 | iter 755 step 94 | loss train: 1.034, val: n/a | iter time: 3633.77 ms\n",
            "Epoch 1 | iter 756 step 94 | loss train: 1.053, val: n/a | iter time: 3312.88 ms\n",
            "Epoch 1 | iter 757 step 94 | loss train: 1.027, val: n/a | iter time: 3216.66 ms\n",
            "Epoch 1 | iter 758 step 94 | loss train: 1.012, val: n/a | iter time: 3275.44 ms\n",
            "Epoch 1 | iter 759 step 94 | loss train: 1.050, val: n/a | iter time: 657.83 ms\n",
            "Epoch 1 | iter 760 step 95 | loss train: 1.078, val: n/a | iter time: 3699.22 ms (step)\n",
            "Epoch 1 | iter 761 step 95 | loss train: 1.077, val: n/a | iter time: 1687.91 ms\n",
            "Epoch 1 | iter 762 step 95 | loss train: 1.079, val: n/a | iter time: 3257.60 ms\n",
            "Epoch 1 | iter 763 step 95 | loss train: 1.007, val: n/a | iter time: 669.26 ms\n",
            "Epoch 1 | iter 764 step 95 | loss train: 0.997, val: n/a | iter time: 1589.33 ms\n",
            "Epoch 1 | iter 765 step 95 | loss train: 1.001, val: n/a | iter time: 905.86 ms\n",
            "Epoch 1 | iter 766 step 95 | loss train: 0.995, val: n/a | iter time: 1320.01 ms\n",
            "Epoch 1 | iter 767 step 95 | loss train: 0.981, val: n/a | iter time: 3027.10 ms\n",
            "Epoch 1 | iter 768 step 96 | loss train: 0.920, val: n/a | iter time: 4005.54 ms (step)\n",
            "Epoch 1 | iter 769 step 96 | loss train: 0.896, val: n/a | iter time: 3548.18 ms\n",
            "Epoch 1 | iter 770 step 96 | loss train: 0.964, val: n/a | iter time: 2625.38 ms\n",
            "Epoch 1 | iter 771 step 96 | loss train: 1.052, val: n/a | iter time: 3205.30 ms\n",
            "Epoch 1 | iter 772 step 96 | loss train: 1.031, val: n/a | iter time: 4691.52 ms\n",
            "Epoch 1 | iter 773 step 96 | loss train: 1.002, val: n/a | iter time: 5432.95 ms\n",
            "Epoch 1 | iter 774 step 96 | loss train: 1.065, val: n/a | iter time: 5191.70 ms\n",
            "Epoch 1 | iter 775 step 96 | loss train: 1.121, val: n/a | iter time: 3000.99 ms\n",
            "Epoch 1 | iter 776 step 97 | loss train: 1.158, val: n/a | iter time: 2231.16 ms (step)\n",
            "Epoch 1 | iter 777 step 97 | loss train: 1.174, val: n/a | iter time: 1966.67 ms\n",
            "Epoch 1 | iter 778 step 97 | loss train: 1.078, val: n/a | iter time: 1689.00 ms\n",
            "Epoch 1 | iter 779 step 97 | loss train: 1.095, val: n/a | iter time: 5360.22 ms\n",
            "Epoch 1 | iter 780 step 97 | loss train: 1.111, val: n/a | iter time: 4287.02 ms\n",
            "Epoch 1 | iter 781 step 97 | loss train: 1.088, val: n/a | iter time: 1216.10 ms\n",
            "Epoch 1 | iter 782 step 97 | loss train: 1.027, val: n/a | iter time: 4037.42 ms\n",
            "Epoch 1 | iter 783 step 97 | loss train: 0.973, val: n/a | iter time: 3911.48 ms\n",
            "Epoch 1 | iter 784 step 98 | loss train: 0.959, val: n/a | iter time: 3160.61 ms (step)\n",
            "Epoch 1 | iter 785 step 98 | loss train: 0.952, val: n/a | iter time: 3962.89 ms\n",
            "Epoch 1 | iter 786 step 98 | loss train: 0.979, val: n/a | iter time: 956.39 ms\n",
            "Epoch 1 | iter 787 step 98 | loss train: 0.876, val: n/a | iter time: 2285.17 ms\n",
            "Epoch 1 | iter 788 step 98 | loss train: 0.909, val: n/a | iter time: 970.05 ms\n",
            "Epoch 1 | iter 789 step 98 | loss train: 0.954, val: n/a | iter time: 2607.60 ms\n",
            "Epoch 1 | iter 790 step 98 | loss train: 0.988, val: n/a | iter time: 3263.03 ms\n",
            "Epoch 1 | iter 791 step 98 | loss train: 0.951, val: n/a | iter time: 969.96 ms\n",
            "Epoch 1 | iter 792 step 99 | loss train: 0.912, val: n/a | iter time: 3542.31 ms (step)\n",
            "Epoch 1 | iter 793 step 99 | loss train: 0.951, val: n/a | iter time: 2916.16 ms\n",
            "Epoch 1 | iter 794 step 99 | loss train: 0.955, val: n/a | iter time: 1943.58 ms\n",
            "Epoch 1 | iter 795 step 99 | loss train: 1.032, val: n/a | iter time: 973.27 ms\n",
            "Epoch 1 | iter 796 step 99 | loss train: 1.002, val: n/a | iter time: 3681.58 ms\n",
            "Epoch 1 | iter 797 step 99 | loss train: 0.989, val: n/a | iter time: 2238.27 ms\n",
            "Epoch 1 | iter 798 step 99 | loss train: 0.977, val: n/a | iter time: 3930.28 ms\n",
            "Epoch 1 | iter 799 step 99 | loss train: 1.013, val: n/a | iter time: 2253.92 ms\n",
            "Epoch 1 | iter 800 step 100 | loss train: 1.026, val: n/a | iter time: 2568.64 ms (step)\n",
            "Validating ...\n",
            "Generate a funny caption for the following photo.\n",
            "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Generate a funny caption for the following photo.\n",
            "\n",
            "### Response:\n",
            "Here's a photo of a cat posing with a funny face!\n",
            "\n",
            "iter 800: val loss 0.9886, val time: 68844.65 ms\n",
            "Epoch 1 | iter 801 step 100 | loss train: 0.991, val: 0.989 | iter time: 5621.80 ms\n",
            "Epoch 1 | iter 802 step 100 | loss train: 0.974, val: 0.989 | iter time: 3612.34 ms\n",
            "Epoch 1 | iter 803 step 100 | loss train: 0.961, val: 0.989 | iter time: 3181.32 ms\n",
            "Epoch 1 | iter 804 step 100 | loss train: 0.955, val: 0.989 | iter time: 3263.35 ms\n",
            "Epoch 1 | iter 805 step 100 | loss train: 1.011, val: 0.989 | iter time: 2345.45 ms\n",
            "Epoch 1 | iter 806 step 100 | loss train: 1.015, val: 0.989 | iter time: 3270.78 ms\n",
            "Epoch 1 | iter 807 step 100 | loss train: 1.045, val: 0.989 | iter time: 3280.39 ms\n",
            "Epoch 1 | iter 808 step 101 | loss train: 1.087, val: 0.989 | iter time: 1905.34 ms (step)\n",
            "Epoch 1 | iter 809 step 101 | loss train: 1.063, val: 0.989 | iter time: 1237.19 ms\n",
            "Epoch 1 | iter 810 step 101 | loss train: 1.065, val: 0.989 | iter time: 3295.49 ms\n",
            "Epoch 1 | iter 811 step 101 | loss train: 1.058, val: 0.989 | iter time: 3581.96 ms\n",
            "Epoch 1 | iter 812 step 101 | loss train: 1.029, val: 0.989 | iter time: 3025.42 ms\n",
            "Epoch 1 | iter 813 step 101 | loss train: 0.987, val: 0.989 | iter time: 2540.23 ms\n",
            "Epoch 1 | iter 814 step 101 | loss train: 1.051, val: 0.989 | iter time: 4585.23 ms\n",
            "Epoch 1 | iter 815 step 101 | loss train: 1.021, val: 0.989 | iter time: 2347.89 ms\n",
            "Epoch 1 | iter 816 step 102 | loss train: 0.960, val: 0.989 | iter time: 2561.09 ms (step)\n",
            "Epoch 1 | iter 817 step 102 | loss train: 0.976, val: 0.989 | iter time: 3312.80 ms\n",
            "Epoch 1 | iter 818 step 102 | loss train: 0.972, val: 0.989 | iter time: 3954.90 ms\n",
            "Epoch 1 | iter 819 step 102 | loss train: 1.010, val: 0.989 | iter time: 3000.18 ms\n",
            "Epoch 1 | iter 820 step 102 | loss train: 1.048, val: 0.989 | iter time: 3572.12 ms\n",
            "Epoch 1 | iter 821 step 102 | loss train: 1.037, val: 0.989 | iter time: 4097.41 ms\n",
            "Epoch 1 | iter 822 step 102 | loss train: 0.989, val: 0.989 | iter time: 3660.53 ms\n",
            "Epoch 1 | iter 823 step 102 | loss train: 1.017, val: 0.989 | iter time: 3784.79 ms\n",
            "Epoch 1 | iter 824 step 103 | loss train: 1.086, val: 0.989 | iter time: 974.26 ms (step)\n",
            "Epoch 1 | iter 825 step 103 | loss train: 1.136, val: 0.989 | iter time: 4008.04 ms\n",
            "Epoch 1 | iter 826 step 103 | loss train: 1.158, val: 0.989 | iter time: 1666.03 ms\n",
            "Epoch 1 | iter 827 step 103 | loss train: 1.159, val: 0.989 | iter time: 2942.33 ms\n",
            "Epoch 1 | iter 828 step 103 | loss train: 1.171, val: 0.989 | iter time: 2359.50 ms\n",
            "Epoch 1 | iter 829 step 103 | loss train: 1.162, val: 0.989 | iter time: 2984.34 ms\n",
            "Epoch 1 | iter 830 step 103 | loss train: 1.131, val: 0.989 | iter time: 1211.06 ms\n",
            "Epoch 1 | iter 831 step 103 | loss train: 1.095, val: 0.989 | iter time: 2240.68 ms\n",
            "Epoch 1 | iter 832 step 104 | loss train: 1.110, val: 0.989 | iter time: 2249.58 ms (step)\n",
            "Epoch 1 | iter 833 step 104 | loss train: 1.081, val: 0.989 | iter time: 3009.71 ms\n",
            "Epoch 1 | iter 834 step 104 | loss train: 1.092, val: 0.989 | iter time: 1600.63 ms\n",
            "Epoch 1 | iter 835 step 104 | loss train: 1.034, val: 0.989 | iter time: 1348.22 ms\n",
            "Epoch 1 | iter 836 step 104 | loss train: 1.003, val: 0.989 | iter time: 2996.91 ms\n",
            "Epoch 1 | iter 837 step 104 | loss train: 0.989, val: 0.989 | iter time: 884.91 ms\n",
            "Epoch 1 | iter 838 step 104 | loss train: 1.037, val: 0.989 | iter time: 3981.70 ms\n",
            "Epoch 1 | iter 839 step 104 | loss train: 0.988, val: 0.989 | iter time: 973.48 ms\n",
            "Epoch 1 | iter 840 step 105 | loss train: 0.946, val: 0.989 | iter time: 1667.45 ms (step)\n",
            "Epoch 1 | iter 841 step 105 | loss train: 0.953, val: 0.989 | iter time: 5441.40 ms\n",
            "Epoch 1 | iter 842 step 105 | loss train: 0.914, val: 0.989 | iter time: 1588.92 ms\n",
            "Epoch 1 | iter 843 step 105 | loss train: 0.890, val: 0.989 | iter time: 2629.57 ms\n",
            "Epoch 1 | iter 844 step 105 | loss train: 0.916, val: 0.989 | iter time: 2920.09 ms\n",
            "Epoch 1 | iter 845 step 105 | loss train: 0.909, val: 0.989 | iter time: 1673.12 ms\n",
            "Epoch 1 | iter 846 step 105 | loss train: 0.847, val: 0.989 | iter time: 1964.64 ms\n",
            "Epoch 1 | iter 847 step 105 | loss train: 0.895, val: 0.989 | iter time: 1868.93 ms\n",
            "Epoch 1 | iter 848 step 106 | loss train: 0.913, val: 0.989 | iter time: 4565.15 ms (step)\n",
            "Epoch 1 | iter 849 step 106 | loss train: 0.910, val: 0.989 | iter time: 3269.44 ms\n",
            "Epoch 1 | iter 850 step 106 | loss train: 0.930, val: 0.989 | iter time: 2917.90 ms\n",
            "Epoch 1 | iter 851 step 106 | loss train: 1.006, val: 0.989 | iter time: 1590.90 ms\n",
            "Epoch 1 | iter 852 step 106 | loss train: 0.948, val: 0.989 | iter time: 4260.50 ms\n",
            "Epoch 1 | iter 853 step 106 | loss train: 0.956, val: 0.989 | iter time: 2214.07 ms\n",
            "Epoch 1 | iter 854 step 106 | loss train: 0.972, val: 0.989 | iter time: 3967.07 ms\n",
            "Epoch 1 | iter 855 step 106 | loss train: 1.015, val: 0.989 | iter time: 1966.31 ms\n",
            "Epoch 1 | iter 856 step 107 | loss train: 1.023, val: 0.989 | iter time: 1993.91 ms (step)\n",
            "Epoch 1 | iter 857 step 107 | loss train: 1.013, val: 0.989 | iter time: 3957.56 ms\n",
            "Epoch 1 | iter 858 step 107 | loss train: 1.028, val: 0.989 | iter time: 2326.55 ms\n",
            "Epoch 1 | iter 859 step 107 | loss train: 0.968, val: 0.989 | iter time: 651.34 ms\n",
            "Epoch 1 | iter 860 step 107 | loss train: 1.143, val: 0.989 | iter time: 4291.60 ms\n",
            "Epoch 1 | iter 861 step 107 | loss train: 1.151, val: 0.989 | iter time: 1686.46 ms\n",
            "Epoch 1 | iter 862 step 107 | loss train: 1.119, val: 0.989 | iter time: 1679.80 ms\n",
            "Epoch 1 | iter 863 step 107 | loss train: 1.061, val: 0.989 | iter time: 1211.26 ms\n",
            "Epoch 1 | iter 864 step 108 | loss train: 1.043, val: 0.989 | iter time: 3956.42 ms (step)\n",
            "Epoch 1 | iter 865 step 108 | loss train: 1.041, val: 0.989 | iter time: 877.74 ms\n",
            "Epoch 1 | iter 866 step 108 | loss train: 1.048, val: 0.989 | iter time: 5516.01 ms\n",
            "Epoch 1 | iter 867 step 108 | loss train: 1.058, val: 0.989 | iter time: 1598.64 ms\n",
            "Epoch 1 | iter 868 step 108 | loss train: 0.921, val: 0.989 | iter time: 5248.12 ms\n",
            "Epoch 1 | iter 869 step 108 | loss train: 0.956, val: 0.989 | iter time: 1315.86 ms\n",
            "Epoch 1 | iter 870 step 108 | loss train: 0.984, val: 0.989 | iter time: 5353.35 ms\n",
            "Epoch 1 | iter 871 step 108 | loss train: 0.955, val: 0.989 | iter time: 877.35 ms\n",
            "Epoch 1 | iter 872 step 109 | loss train: 0.935, val: 0.989 | iter time: 1247.39 ms (step)\n",
            "Epoch 1 | iter 873 step 109 | loss train: 0.925, val: 0.989 | iter time: 3614.44 ms\n",
            "Epoch 1 | iter 874 step 109 | loss train: 0.864, val: 0.989 | iter time: 960.31 ms\n",
            "Epoch 1 | iter 875 step 109 | loss train: 0.936, val: 0.989 | iter time: 2635.18 ms\n",
            "Epoch 1 | iter 876 step 109 | loss train: 0.914, val: 0.989 | iter time: 1968.42 ms\n",
            "Epoch 1 | iter 877 step 109 | loss train: 0.838, val: 0.989 | iter time: 880.43 ms\n",
            "Epoch 1 | iter 878 step 109 | loss train: 0.848, val: 0.989 | iter time: 3568.70 ms\n",
            "Epoch 1 | iter 879 step 109 | loss train: 0.916, val: 0.989 | iter time: 4773.78 ms\n",
            "Epoch 1 | iter 880 step 110 | loss train: 0.892, val: 0.989 | iter time: 5289.15 ms (step)\n",
            "Epoch 1 | iter 881 step 110 | loss train: 0.931, val: 0.989 | iter time: 5585.64 ms\n",
            "Epoch 1 | iter 882 step 110 | loss train: 0.974, val: 0.989 | iter time: 2614.24 ms\n",
            "Epoch 1 | iter 883 step 110 | loss train: 0.915, val: 0.989 | iter time: 3073.37 ms\n",
            "Epoch 1 | iter 884 step 110 | loss train: 0.956, val: 0.989 | iter time: 1249.15 ms\n",
            "Epoch 1 | iter 885 step 110 | loss train: 1.045, val: 0.989 | iter time: 3776.31 ms\n",
            "Epoch 1 | iter 886 step 110 | loss train: 1.065, val: 0.989 | iter time: 3271.75 ms\n",
            "Epoch 1 | iter 887 step 110 | loss train: 1.056, val: 0.989 | iter time: 2646.48 ms\n",
            "Epoch 1 | iter 888 step 111 | loss train: 1.078, val: 0.989 | iter time: 985.54 ms (step)\n",
            "Epoch 1 | iter 889 step 111 | loss train: 1.091, val: 0.989 | iter time: 1936.03 ms\n",
            "Epoch 1 | iter 890 step 111 | loss train: 1.069, val: 0.989 | iter time: 1319.00 ms\n",
            "Epoch 1 | iter 891 step 111 | loss train: 1.084, val: 0.989 | iter time: 5454.77 ms\n",
            "Epoch 1 | iter 892 step 111 | loss train: 1.019, val: 0.989 | iter time: 692.32 ms\n",
            "Epoch 1 | iter 893 step 111 | loss train: 0.985, val: 0.989 | iter time: 1896.82 ms\n",
            "Epoch 1 | iter 894 step 111 | loss train: 0.978, val: 0.989 | iter time: 4678.99 ms\n",
            "Epoch 1 | iter 895 step 111 | loss train: 0.987, val: 0.989 | iter time: 3003.61 ms\n",
            "Epoch 1 | iter 896 step 112 | loss train: 0.988, val: 0.989 | iter time: 871.85 ms (step)\n",
            "Epoch 1 | iter 897 step 112 | loss train: 0.961, val: 0.989 | iter time: 1954.60 ms\n",
            "Epoch 1 | iter 898 step 112 | loss train: 0.957, val: 0.989 | iter time: 2922.78 ms\n",
            "Epoch 1 | iter 899 step 112 | loss train: 1.006, val: 0.989 | iter time: 2347.79 ms\n",
            "Epoch 1 | iter 900 step 112 | loss train: 1.052, val: 0.989 | iter time: 5474.82 ms\n",
            "Epoch 1 | iter 901 step 112 | loss train: 1.077, val: 0.989 | iter time: 4765.45 ms\n",
            "Epoch 1 | iter 902 step 112 | loss train: 1.029, val: 0.989 | iter time: 1911.17 ms\n",
            "Epoch 1 | iter 903 step 112 | loss train: 1.018, val: 0.989 | iter time: 4433.55 ms\n",
            "Epoch 1 | iter 904 step 113 | loss train: 1.030, val: 0.989 | iter time: 3205.20 ms (step)\n",
            "Epoch 1 | iter 905 step 113 | loss train: 0.964, val: 0.989 | iter time: 884.30 ms\n",
            "Epoch 1 | iter 906 step 113 | loss train: 0.994, val: 0.989 | iter time: 2381.01 ms\n",
            "Epoch 1 | iter 907 step 113 | loss train: 0.932, val: 0.989 | iter time: 2627.88 ms\n",
            "Epoch 1 | iter 908 step 113 | loss train: 0.954, val: 0.989 | iter time: 1991.19 ms\n",
            "Epoch 1 | iter 909 step 113 | loss train: 0.920, val: 0.989 | iter time: 990.05 ms\n",
            "Epoch 1 | iter 910 step 113 | loss train: 0.908, val: 0.989 | iter time: 906.53 ms\n",
            "Epoch 1 | iter 911 step 113 | loss train: 0.899, val: 0.989 | iter time: 3905.83 ms\n",
            "Epoch 1 | iter 912 step 114 | loss train: 0.906, val: 0.989 | iter time: 998.50 ms (step)\n",
            "Epoch 1 | iter 913 step 114 | loss train: 0.935, val: 0.989 | iter time: 1669.96 ms\n",
            "Epoch 1 | iter 914 step 114 | loss train: 0.948, val: 0.989 | iter time: 3013.54 ms\n",
            "Epoch 1 | iter 915 step 114 | loss train: 0.983, val: 0.989 | iter time: 980.22 ms\n",
            "Epoch 1 | iter 916 step 114 | loss train: 0.978, val: 0.989 | iter time: 1243.21 ms\n",
            "Epoch 1 | iter 917 step 114 | loss train: 0.983, val: 0.989 | iter time: 1605.24 ms\n",
            "Epoch 1 | iter 918 step 114 | loss train: 1.036, val: 0.989 | iter time: 3900.88 ms\n",
            "Epoch 1 | iter 919 step 114 | loss train: 0.982, val: 0.989 | iter time: 3296.79 ms\n",
            "Epoch 1 | iter 920 step 115 | loss train: 0.989, val: 0.989 | iter time: 2221.27 ms (step)\n",
            "Epoch 1 | iter 921 step 115 | loss train: 1.015, val: 0.989 | iter time: 2992.25 ms\n",
            "Epoch 1 | iter 922 step 115 | loss train: 1.020, val: 0.989 | iter time: 2600.11 ms\n",
            "Epoch 1 | iter 923 step 115 | loss train: 0.979, val: 0.989 | iter time: 3886.29 ms\n",
            "Epoch 1 | iter 924 step 115 | loss train: 0.961, val: 0.989 | iter time: 3756.86 ms\n",
            "Epoch 1 | iter 925 step 115 | loss train: 0.994, val: 0.989 | iter time: 3007.14 ms\n",
            "Epoch 1 | iter 926 step 115 | loss train: 0.970, val: 0.989 | iter time: 996.94 ms\n",
            "Epoch 1 | iter 927 step 115 | loss train: 1.021, val: 0.989 | iter time: 4314.87 ms\n",
            "Epoch 1 | iter 928 step 116 | loss train: 1.016, val: 0.989 | iter time: 5054.28 ms (step)\n",
            "Epoch 1 | iter 929 step 116 | loss train: 0.976, val: 0.989 | iter time: 988.42 ms\n",
            "Epoch 1 | iter 930 step 116 | loss train: 0.957, val: 0.989 | iter time: 2236.74 ms\n",
            "Epoch 1 | iter 931 step 116 | loss train: 0.928, val: 0.989 | iter time: 1675.32 ms\n",
            "Epoch 1 | iter 932 step 116 | loss train: 0.910, val: 0.989 | iter time: 2314.80 ms\n",
            "Epoch 1 | iter 933 step 116 | loss train: 0.880, val: 0.989 | iter time: 1603.15 ms\n",
            "Epoch 1 | iter 934 step 116 | loss train: 0.912, val: 0.989 | iter time: 5004.78 ms\n",
            "Epoch 1 | iter 935 step 116 | loss train: 0.921, val: 0.989 | iter time: 5475.90 ms\n",
            "Epoch 1 | iter 936 step 117 | loss train: 0.908, val: 0.989 | iter time: 5457.01 ms (step)\n",
            "Epoch 1 | iter 937 step 117 | loss train: 0.955, val: 0.989 | iter time: 2986.70 ms\n",
            "Epoch 1 | iter 938 step 117 | loss train: 0.947, val: 0.989 | iter time: 2655.41 ms\n",
            "Epoch 1 | iter 939 step 117 | loss train: 0.991, val: 0.989 | iter time: 1582.65 ms\n",
            "Epoch 1 | iter 940 step 117 | loss train: 0.944, val: 0.989 | iter time: 891.04 ms\n",
            "Epoch 1 | iter 941 step 117 | loss train: 0.949, val: 0.989 | iter time: 1985.82 ms\n",
            "Epoch 1 | iter 942 step 117 | loss train: 0.907, val: 0.989 | iter time: 876.23 ms\n",
            "Epoch 1 | iter 943 step 117 | loss train: 0.871, val: 0.989 | iter time: 898.04 ms\n",
            "Epoch 1 | iter 944 step 118 | loss train: 0.876, val: 0.989 | iter time: 982.38 ms (step)\n",
            "Epoch 1 | iter 945 step 118 | loss train: 0.858, val: 0.989 | iter time: 1594.80 ms\n",
            "Epoch 1 | iter 946 step 118 | loss train: 0.830, val: 0.989 | iter time: 3701.00 ms\n",
            "Epoch 1 | iter 947 step 118 | loss train: 0.785, val: 0.989 | iter time: 882.71 ms\n",
            "Epoch 1 | iter 948 step 118 | loss train: 0.853, val: 0.989 | iter time: 1345.64 ms\n",
            "Epoch 1 | iter 949 step 118 | loss train: 0.875, val: 0.989 | iter time: 2663.54 ms\n",
            "Epoch 1 | iter 950 step 118 | loss train: 0.885, val: 0.989 | iter time: 1330.23 ms\n",
            "\n",
            "| ------------------------------------------------------\n",
            "| Token Counts\n",
            "| - Input Tokens              :  277413\n",
            "| - Tokens w/ Prompt          :  353155\n",
            "| - Total Tokens (w/ Padding) :  480188\n",
            "| -----------------------------------------------------\n",
            "| Performance\n",
            "| - Training Time             :  2681.59 s\n",
            "| - Tok/sec                   :  179.07 tok/s\n",
            "| -----------------------------------------------------\n",
            "| Memory Usage                                                                 \n",
            "| - Memory Used               :  13.10 GB                                        \n",
            "-------------------------------------------------------\n",
            "\n",
            "Validating ...\n",
            "Final evaluation | val loss: 0.982 | val ppl: 2.670\n",
            "Saving LoRA weights to 'out/lit-finetuned/gemma-2-alpaca-it/final/lit_model.pth.lora'\n",
            "{'checkpoint_dir': PosixPath('out/lit-finetuned/gemma-2-alpaca-it/final'),\n",
            " 'precision': None,\n",
            " 'pretrained_checkpoint_dir': None}\n",
            "/usr/local/lib/python3.10/dist-packages/litgpt/scripts/merge_lora.py:67: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  pretrained_checkpoint = torch.load(str(pretrained_checkpoint_dir / \"lit_model.pth\"), mmap=True)\n",
            "/usr/local/lib/python3.10/dist-packages/litgpt/scripts/merge_lora.py:68: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  lora_checkpoint = torch.load(str(lora_path), mmap=True)\n",
            "Saved merged weights to 'out/lit-finetuned/gemma-2-alpaca-it/final/lit_model.pth'\n"
          ]
        }
      ],
      "source": [
        "os.environ[\"FINETUNED_MODEL_DIR\"] = \"out/lit-finetuned/gemma-2-alpaca-it\"\n",
        "\n",
        "!litgpt finetune_lora google/gemma-2-2b \\\n",
        "  --data Alpaca2k \\\n",
        "  --train.max_seq_length 512 \\\n",
        "  --train.micro_batch_size 2 \\\n",
        "  --train.epochs 1 \\\n",
        "  --out_dir $FINETUNED_MODEL_DIR \\\n",
        "  --precision bf16-true"
      ]
    },
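    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The summary printed at the end of the training log can be cross-checked with a little arithmetic. The sketch below rests on three assumptions inferred from the log rather than taken from the LitGPT source: that the reported perplexity is the exponential of the validation loss (the usual definition), that `Tok/sec` is the padded token count divided by the training time, and that an optimizer step is taken every 8 iterations, as the `(step)` markers suggest.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "# Values copied from the training log of this run.\n",
        "val_loss = 0.982\n",
        "total_tokens_with_padding = 480_188\n",
        "training_time_s = 2681.59\n",
        "micro_batch_size = 2\n",
        "iters_per_step = 8  # Assumption: inferred from the (step) markers in the log.\n",
        "\n",
        "# Assumption: perplexity is exp(mean cross-entropy loss).\n",
        "perplexity = math.exp(val_loss)\n",
        "# Assumption: throughput counts padded tokens over the whole training time.\n",
        "tokens_per_sec = total_tokens_with_padding / training_time_s\n",
        "# Samples that contribute to one optimizer step under gradient accumulation.\n",
        "samples_per_step = micro_batch_size * iters_per_step\n",
        "\n",
        "print(f'val ppl ~ {perplexity:.3f}')\n",
        "print(f'tok/s ~ {tokens_per_sec:.2f}')\n",
        "print(f'samples per optimizer step ~ {samples_per_step}')\n",
        "```\n",
        "\n",
        "If these assumptions hold, the computed values should agree with the `val ppl` and `Tok/sec` figures reported in the log, and each optimizer step would see 16 samples."
      ]
    },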
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eO_KNp0yP0Z3"
      },
      "source": [
        "LitGPT's Python API supports pre-training and fine-tuning using the [PyTorch Lightning Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). You can read more about this in the [LitGPT Python API tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/python-api.md#pytorch-lightning-trainer-support).\n",
        "\n",
        "To customize fine-tuning further you can refer to the [LitGPT custom fine-tuning documentation](https://lightning.ai/lightning-ai/studios/litgpt-quick-start?section=featured#custom-finetuning)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iMqsKsg2JC_7"
      },
      "source": [
        "## 2. Prompt the fine-tuned model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blFwWURCJKAJ"
      },
      "source": [
        "Next, you will test the fine-tuned model using the LitGPT command-line interface.\n",
        "\n",
        "Use `litgpt generate` to prompt the fine-tuned Gemma 2 model. Specify the path to the fine-tuned model checkpoint in the command.\n",
        "\n",
        "Use the `--prompt` argument to specify the query you want the model to answer.\n",
        "\n",
        "You can also specify your preferred values for parameters like `top_k`, `top_p`, `temperature`, `max_new_tokens` etc.\n",
        "\n",
        "Run `litgpt generate --help` to see all configurable parameters.\n",
        "Please refer to the [LitGPT inference tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/inference.md) for more details."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "FFFMyb6oN76n"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'checkpoint_dir': PosixPath('out/lit-finetuned/gemma-2-alpaca-it/final'),\n",
            " 'compile': False,\n",
            " 'max_new_tokens': 50,\n",
            " 'num_samples': 1,\n",
            " 'precision': None,\n",
            " 'prompt': 'Generate the next number in the Fibonnaci series: 1, 1, 2, 3, 5',\n",
            " 'quantize': None,\n",
            " 'temperature': 0.8,\n",
            " 'top_k': 50,\n",
            " 'top_p': 1.0}\n",
            "Loading model 'out/lit-finetuned/gemma-2-alpaca-it/final/lit_model.pth' with {'name': 'Gemma-2-2b', 'hf_config': {'name': 'gemma-2-2b', 'org': 'google'}, 'scale_embeddings': True, 'attention_scores_scalar': 256, 'block_size': 8192, 'sliding_window_size': 4096, 'sliding_window_layer_placing': 2, 'vocab_size': 256000, 'padding_multiple': 512, 'padded_vocab_size': 256000, 'n_layer': 26, 'n_head': 8, 'head_size': 256, 'n_embd': 2304, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'lm_head_bias': False, 'n_query_groups': 4, 'shared_attention_norm': False, 'norm_class_name': 'RMSNorm', 'post_attention_norm': True, 'post_mlp_norm': True, 'norm_eps': 1e-05, 'mlp_class_name': 'GemmaMLP', 'gelu_approximate': 'tanh', 'intermediate_size': 9216, 'rope_condense_ratio': 1, 'rope_base': 10000, 'rope_adjustments': None, 'n_expert': 0, 'n_expert_per_token': 0, 'attention_logit_softcapping': 50.0, 'final_logit_softcapping': 30.0, 'rope_n_elem': 256}\n",
            "Time to instantiate model: 0.14 seconds.\n",
            "Time to load the model weights: 5.07 seconds.\n",
            "Seed set to 1234\n",
            "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Generate the next number in the Fibonnaci series: 1, 1, 2, 3, 5\n",
            "\n",
            "### Response:\n",
            "The next number in the Fibonacci series is 8.\n",
            "Time for inference 1: 1.57 sec total, 7.65 tokens/sec\n",
            "Memory used: 6.52 GB\n"
          ]
        }
      ],
      "source": [
        "!litgpt generate $FINETUNED_MODEL_DIR/final \\\n",
        "  --prompt \"Generate the next number in the Fibonacci series: 1, 1, 2, 3, 5\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_KI4g7rMxbK"
      },
      "source": [
        "## 3. Serve the fine-tuned model\n",
        "You will now serve the fine-tuned model using LitGPT.\n",
        "\n",
        "To deploy your fine-tuned model, use the `litgpt serve` command. Specify the port number using the `--port` argument.\n",
        "\n",
        "When running in a Colab environment, you'll need to manage the LitGPT inference server as a Python subprocess using the `subprocess` package.\n",
        "\n",
        "For more details on deploying LLMs with LitGPT, refer to the [LitGPT Serve and Deploy LLMs tutorial](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/deploy.md).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "OdTnZZ_zN2ID"
      },
      "outputs": [],
      "source": [
        "import subprocess\n",
        "import time\n",
        "\n",
        "command = [\n",
        "    \"litgpt\", \"serve\", \"out/lit-finetuned/gemma-2-alpaca-it/final\", \"--port\", \"30000\"\n",
        "]\n",
        "\n",
        "# Create a file to write logs\n",
        "with open(\"litgpt_serve.log\", \"w\") as logfile:\n",
        "  # Use subprocess.Popen to run the command with nohup-like behavior\n",
        "  server_process = subprocess.Popen(\n",
        "    command,\n",
        "    stdout=logfile,\n",
        "    stderr=subprocess.STDOUT,\n",
        "    stdin=subprocess.PIPE,\n",
        "    start_new_session=True  # This is similar to nohup behavior, detaches from terminal\n",
        "  )\n",
        "\n",
        "  # Send an Enter key (\\n) to the process to accept the terms\n",
        "  server_process.stdin.write(b'\\n')\n",
        "  server_process.stdin.flush()\n",
        "\n",
        "# Sleep for 60 seconds\n",
        "time.sleep(60)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rj7XEYrROenw"
      },
      "source": [
        "The server is now ready and can be reached at http://localhost:30000/ from within this notebook."
      ]
    },
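    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A fixed 60-second wait works for this run, but the time the server needs to start depends on the hardware. One alternative is to poll the port until it accepts connections. The helper below is a minimal sketch using only the standard library; `wait_for_port` is a hypothetical name and not part of LitGPT, and it only checks that something is listening on the port, not that the model has finished loading.\n",
        "\n",
        "```python\n",
        "import socket\n",
        "import time\n",
        "\n",
        "\n",
        "def wait_for_port(host: str, port: int, timeout: float = 120.0) -> bool:\n",
        "    # Return True once a TCP connection to (host, port) succeeds,\n",
        "    # or False if no connection succeeds before the timeout elapses.\n",
        "    deadline = time.monotonic() + timeout\n",
        "    while time.monotonic() < deadline:\n",
        "        try:\n",
        "            with socket.create_connection((host, port), timeout=1.0):\n",
        "                return True\n",
        "        except OSError:\n",
        "            time.sleep(1.0)  # Nothing is listening yet; retry shortly.\n",
        "    return False\n",
        "```\n",
        "\n",
        "For example, `wait_for_port('localhost', 30000)` could stand in for the fixed `time.sleep(60)` call in the cell that starts the server."
      ]
    },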
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iTeStr5bPECg"
      },
      "source": [
        "### Query the inference server\n",
        "\n",
        "You can prompt the fine-tuned Gemma 2 model deployed via the inference server using Python's `requests` library.\n",
        "\n",
        "You can craft your prompt to adhere to the format of the samples in the Alpaca dataset. Import `prompts` from `litgpt` and instantiate the Alpaca prompt style. Use this prompt style to convert any prompt to the Alpaca prompt format as demonstrated below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "fSWlHqUZbfWG"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
            "\n",
            "### Instruction:\n",
            "Generate the next number in the Fibonnaci series.\n",
            "\n",
            "### Input:\n",
            "1, 1, 2, 3, 5, 8\n",
            "\n",
            "### Response:\n",
            "\n"
          ]
        }
      ],
      "source": [
        "from litgpt import prompts\n",
        "\n",
        "alpaca_prompt_style = prompts.Alpaca()\n",
        "prompt_text = alpaca_prompt_style.apply(prompt=\"Generate the next number in the Fibonacci series.\",\n",
        "                                        input=\"1, 1, 2, 3, 5, 8\")\n",
        "print(prompt_text)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKQ8X6wx3_OU"
      },
      "source": [
        "Use Python's `requests` library to send a prediction request with your prompt to the inference server."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "YOTV170FPkJg"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{\n",
            "  \"output\": \"The next number in the Fibonacci series is 8 as follows:\\n\\n1, 1, 2, 3, 5, 8, 13, 21 and 34, 55, 89\"\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "import requests, json\n",
        "\n",
        "server_url = \" http://localhost:30000/\"\n",
        "\n",
        "response = requests.post(\n",
        "    server_url + \"predict\",\n",
        "    json={\n",
        "      \"prompt\": prompt_text,\n",
        "    }\n",
        ")\n",
        "\n",
        "print(json.dumps(response.json(), indent=2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I4cH7Gji4Q0c"
      },
      "source": [
        "You can stop the inference server by killing the server process you started earlier."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "hIVmbvP8b7v-"
      },
      "outputs": [],
      "source": [
        "server_process.kill()"
      ]
    },
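    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the server was started with `start_new_session=True`, it runs in its own process group. `server_process.kill()` signals only the process that `Popen` started, so any worker processes the server may have spawned could keep running. The sketch below, which assumes a POSIX system such as the Colab runtime, sends `SIGTERM` to the whole group and falls back to `SIGKILL` if the process does not exit in time. `stop_process_group` is a hypothetical helper name, not a LitGPT function.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import signal\n",
        "import subprocess\n",
        "\n",
        "\n",
        "def stop_process_group(process: subprocess.Popen, timeout: float = 10.0) -> int:\n",
        "    # Only use this on a process started with start_new_session=True;\n",
        "    # otherwise the group also contains the calling process.\n",
        "    if process.poll() is not None:\n",
        "        return process.returncode  # Already exited.\n",
        "    pgid = os.getpgid(process.pid)\n",
        "    os.killpg(pgid, signal.SIGTERM)  # Ask the whole group to shut down.\n",
        "    try:\n",
        "        return process.wait(timeout=timeout)\n",
        "    except subprocess.TimeoutExpired:\n",
        "        os.killpg(pgid, signal.SIGKILL)  # Force-stop anything still running.\n",
        "        return process.wait()\n",
        "```\n",
        "\n",
        "Calling `stop_process_group(server_process)` would then take the place of `server_process.kill()`."
      ]
    },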
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n4qZPKd-oAtK"
      },
      "source": [
        "## 4. Push fine-tuned model to Hugging Face Hub\n",
        "To push the fine-tuned model to the Hugging Face Hub, it must be converted to a format compatible with Hugging Face Transformers.\n",
        "\n",
        "Since the model was fine-tuned using the LoRA technique, you must first run the following command on the fine-tuned model directory:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "yhxSzOdGi4e5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'checkpoint_dir': PosixPath('out/lit-finetuned/gemma-2-alpaca-it/final'),\n",
            " 'precision': None,\n",
            " 'pretrained_checkpoint_dir': None}\n",
            "LoRA weights have already been merged in this checkpoint.\n"
          ]
        }
      ],
      "source": [
        "!litgpt merge_lora $FINETUNED_MODEL_DIR/final"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "waMZ__UbpX7S"
      },
      "source": [
        "Next, convert the fine-tuned model to Hugging Face format using the following command. Specify the path to the LitGPT model to be converted and the desired output directory as arguments.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "hlDgyYakjSdt"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'checkpoint_dir': PosixPath('out/lit-finetuned/gemma-2-alpaca-it/final'),\n",
            " 'output_dir': PosixPath('out/hf-format/gemma2-finetuned-it')}\n"
          ]
        }
      ],
      "source": [
        "# Set an environment variable for the output directory where the model must be\n",
        "# saved after conversion to Hugging Face compatible format.\n",
        "# This will be used later to push the model to the hub.\n",
        "os.environ[\"FINETUNED_HF_MODEL\"] = \"out/hf-format/gemma2-finetuned-it\"\n",
        "\n",
        "!litgpt convert_from_litgpt $FINETUNED_MODEL_DIR/final/ $FINETUNED_HF_MODEL"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fO7JBimUqNTS"
      },
      "source": [
        "To push the model to Hugging Face Hub, you must load the fine-tuned model weights into a `Transformers` model.\n",
        "\n",
        "First, a state dictionary is loaded from the fine-tuned model weights. Then, an instance of `AutoModel` is created from the configuration of the pre-trained `gemma-2-2b` model, loaded with the weights of the fine-tuned model.\n",
        "\n",
        "To learn more about converting LitGPT models to Hugging Face Transformers, refer to the [converting LitGPT weights to Hugging Face Transformers tutorial](https://github.com/Lightning-AI/litgpt/blob/7449dad90740c4b0947a6ccb474b869ef969e110/tutorials/convert_lit_models.md).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "SpmRPLg0jbv4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "2.4.1+cu121\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "<ipython-input-1-a94d200608df>:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
            "  state_dict = torch.load(\"out/hf-format/gemma2-finetuned-it/model.pth\")\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "5f26379cbf694fbb9b04c54db72968a2",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "974621360700407d8cc87e51c6ea2b66",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b0d6441a1c874f39998436127320ad12",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "79d3c4ff2ea54ff483afbccdf5316ae8",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "01bde09d40ad445bb0cf116f76a4d66c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import AutoModel\n",
        "\n",
        "state_dict = torch.load(\"out/hf-format/gemma2-finetuned-it/model.pth\")\n",
        "model = AutoModel.from_pretrained(\"google/gemma-2-2b\", state_dict=state_dict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6I8bNMXr6nTu"
      },
      "source": [
        "Use the model's `push_to_hub()` method to upload the model to Hugging Face Hub.\n",
        "\n",
        "**Notes**:\n",
        "1. In the following code snippet, replace \"your_hf_username\" with your Hugging Face username.\n",
        "2. Your Hugging Face token needs to have write permissions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "NwLUgcjD5h6l"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "048b78a1ea3042a28ba196958696b789",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "No files have been modified since last commit. Skipping to prevent empty commit.\n",
            "WARNING:huggingface_hub.hf_api:No files have been modified since last commit. Skipping to prevent empty commit.\n"
          ]
        },
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "CommitInfo(commit_url='https://huggingface.co/prianka-kariat/gemma_finetuned_alpaca_it/commit/445cdc9120aacc566cd9582249095f94b64a98dd', commit_message='Upload model', commit_description='', oid='445cdc9120aacc566cd9582249095f94b64a98dd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/prianka-kariat/gemma_finetuned_alpaca_it', endpoint='https://huggingface.co', repo_type='model', repo_id='prianka-kariat/gemma_finetuned_alpaca_it'), pr_revision=None, pr_num=None)"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model.push_to_hub(\"your_hf_username/gemma_finetuned_alpaca_it\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vlRAC8OFa1pQ"
      },
      "source": [
        "Congratulations! You've successfully fine-tuned, run, and served Gemma 2 using LitGPT.\n",
        "\n",
        "What's next?\n",
        "\n",
        "Your next steps could include the following:\n",
        "\n",
        "**Experiment with different datasets**: Try fine-tuning on other instruction-tuning datasets supported by LitGPT. You can also implement custom workflows to fine-tune the model on other datasets from the Hugging Face Hub, or on your own data, to adapt it to other tasks or domains.\n",
        "\n",
        "**Tune hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, number of epochs, LoRA settings) to improve model quality and training efficiency.\n",
        "\n",
        "**Evaluate the fine-tuned model**: Measure the fine-tuned model's quality using LitGPT's command-line interface, for example with the `litgpt evaluate` command.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_LitGPT.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
